torch-scatter

Version 2.1.2 | verified Fri May 01 | auth: no | language: Python

PyTorch Extension Library of Optimized Scatter Operations. It provides efficient scatter, gather, segment, and softmax operations on PyTorch tensors and is commonly used in graph neural networks. Current version: 2.1.2, released October 2023. Maintained by Matthias Fey (GitHub: rusty1s) as part of the PyTorch Geometric ecosystem.

pip install torch-scatter
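error ImportError: cannot import name 'scatter_max' from 'torch_scatter'
cause Python found something named `torch_scatter`, but not a working installation. Either the install is broken or incomplete, or a local file or folder named `torch_scatter` is shadowing the installed package. If the package is missing entirely, Python raises `ModuleNotFoundError` instead.
fix
Reinstall with `pip install torch-scatter`, or from the PyG wheel index that matches your PyTorch build. Rename any local `torch_scatter` file or folder, then import with `from torch_scatter import scatter_max`.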
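To tell a shadowing local file apart from a broken install, it can help to ask Python where it would load `torch_scatter` from. Below is a minimal stdlib sketch. `module_origin` is a hypothetical helper built on `importlib.util.find_spec`, and for a top-level name it locates the module without importing it.

```python
import importlib.util
from typing import Optional


def module_origin(name: str) -> Optional[str]:
    """Return where Python would import top-level module `name` from, or None.

    A path inside your project (e.g. ./torch_scatter.py) instead of
    site-packages suggests a local file is shadowing the installed package.
    """
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None  # nothing importable under this name
    if spec.origin is None or spec.origin == "namespace":
        # Namespace package (e.g. a bare folder without __init__.py):
        # report the directories it was assembled from instead.
        return ", ".join(spec.submodule_search_locations or [])
    return spec.origin


if __name__ == "__main__":
    print("torch_scatter resolves to:", module_origin("torch_scatter"))
```

A healthy install should report a path under your environment's `site-packages`.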
error RuntimeError: scatter_max_cuda not implemented for 'torch.LongTensor'
cause Some scatter kernels may not support integer (int64) `src` tensors on CUDA.
fix
Cast the values tensor before calling the scatter function, e.g. `src = src.float()` (or `src.int()`). Do not cast `index`: it must stay `torch.long` (int64).
error AttributeError: module 'torch_scatter' has no attribute 'scatter_logsumexp'
cause The installed package is outdated or broken. In the 2.1.x releases, `scatter_logsumexp` is still exported at the top level, alongside `scatter_std`, `scatter_softmax`, and `scatter_log_softmax`.
fix
Check the installed version with `pip show torch-scatter`, then upgrade or reinstall. Import the function as `from torch_scatter import scatter_logsumexp`.
breaking Version 2.1.0 changed `scatter_softmax` so that it no longer accepts an `eps` argument. Code that passes `eps` must be updated.
fix Remove `eps` from `scatter_softmax` calls.
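The per-group math behind `scatter_logsumexp` and `scatter_softmax` can be sketched in plain Python. This is a reference for what the functions compute, not torch-scatter's implementation. The helper names are hypothetical, inputs are 1-D lists, and each distinct value in `index` defines a group.

```python
import math
from typing import Dict, List


def grouped_logsumexp(src: List[float], index: List[int]) -> Dict[int, float]:
    """log(sum(exp(x))) over each group, shifted by the group max for stability."""
    groups: Dict[int, List[float]] = {}
    for value, group in zip(src, index):
        groups.setdefault(group, []).append(value)
    result: Dict[int, float] = {}
    for group, values in groups.items():
        m = max(values)
        # exp(v - m) <= 1, so nothing overflows even for large inputs.
        result[group] = m + math.log(sum(math.exp(v - m) for v in values))
    return result


def grouped_softmax(src: List[float], index: List[int]) -> List[float]:
    """Softmax of each element relative to the other elements of its group."""
    lse = grouped_logsumexp(src, index)
    return [math.exp(v - lse[g]) for v, g in zip(src, index)]
```

Subtracting the group maximum keeps `exp` from overflowing. The maximum element also contributes `exp(0) = 1`, so each group's denominator is at least 1, and this sketch needs no `eps` guard.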
deprecated The in-place `scatter_*_` variants (e.g. `scatter_add_`) are deprecated in favor of the out-of-place functions and may be removed in a future release.
fix Replace `scatter_add_(src, index, out)` with `scatter_add(src, index, dim, out=out)`.
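For the sum reduction, passing `out=` accumulates into the existing tensor instead of starting from zeros. The 1-D plain-Python sketch below mirrors that behavior as we understand it. `scatter_add_ref` is hypothetical and not the library API, so verify the behavior against your installed version before relying on it.

```python
from typing import List, Optional


def scatter_add_ref(src: List[float], index: List[int],
                    out: Optional[List[float]] = None,
                    dim_size: Optional[int] = None) -> List[float]:
    """1-D reference for sum-scatter: out[index[i]] += src[i].

    If `out` is given, values are added onto its existing contents and `out`
    itself is modified and returned. Otherwise a zero-filled list of length
    `dim_size` (default: max(index) + 1) is created.
    """
    if out is None:
        size = dim_size if dim_size is not None else max(index, default=-1) + 1
        out = [0.0] * size
    for value, i in zip(src, index):
        out[i] += value
    return out
```

For example, `scatter_add_ref([1.0, 2.0, 3.0], [0, 1, 0], out=[10.0, 10.0])` yields `[14.0, 12.0]`, not `[4.0, 2.0]`.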
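gotcha With CUDA, the torch-scatter binary must be built for the same PyTorch version and CUDA variant as your installed PyTorch. PyPI only hosts a source distribution, so `pip install torch-scatter` compiles the extension locally. If CUDA is not detected at build time, the result is a CPU-only build.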
fix Install using the PyTorch Geometric wheel index: `pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch_version}+{cu_version}.html`.
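The two placeholders can be derived from the running PyTorch installation. The sketch below assumes the index naming shown above (`torch-<version>+cpu` or `torch-<version>+cu<XY>`) and ignores ROCm builds. `pyg_wheel_index` is a hypothetical helper. If no page is published for your exact patch release, check which index pages exist before installing.

```python
from typing import Optional


def pyg_wheel_index(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the -f URL for the PyG wheel index.

    torch_version: e.g. torch.__version__ ("2.1.0+cu121"); the local
        "+cu121" suffix is dropped.
    cuda_version: e.g. torch.version.cuda ("12.1"), or None for CPU-only builds.
    """
    base = torch_version.split("+")[0]
    suffix = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{suffix}.html"


if __name__ == "__main__":
    try:
        import torch
        url = pyg_wheel_index(torch.__version__, torch.version.cuda)
        print(f"pip install torch-scatter -f {url}")
    except ImportError:
        print("PyTorch is not installed; install it before torch-scatter.")
```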
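gotcha `scatter_std` with `unbiased=False` computes the population (biased) standard deviation: it divides by n rather than n - 1, i.e. without Bessel's correction. For small groups this is noticeably smaller than the sample standard deviation.
fix Do not rely on the default. Pass `unbiased` explicitly: `unbiased=True` for the sample standard deviation, `unbiased=False` for the population standard deviation.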
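The two settings differ only in the divisor. The plain-Python sketch below computes both per group on 1-D lists. `grouped_std` is hypothetical and is not the library code, and its guard for single-element groups is this sketch's own choice.

```python
import math
from typing import Dict, List


def grouped_std(src: List[float], index: List[int], unbiased: bool) -> Dict[int, float]:
    """Per-group standard deviation; unbiased=True divides by n - 1, else by n."""
    groups: Dict[int, List[float]] = {}
    for value, group in zip(src, index):
        groups.setdefault(group, []).append(value)
    result: Dict[int, float] = {}
    for group, values in groups.items():
        n = len(values)
        mean = sum(values) / n
        squared_dev = sum((v - mean) ** 2 for v in values)
        # Bessel's correction for the sample std; single-element groups divide by 1.
        divisor = max(n - 1, 1) if unbiased else n
        result[group] = math.sqrt(squared_dev / divisor)
    return result
```

On `src = [1, 2, 3, 4, 5]` grouped as `[0, 0, 1, 1, 1]`, group 0 gives 0.5 with `unbiased=False` versus about 0.707 with `unbiased=True`.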
pip install torch-scatter -f https://data.pyg.org/whl/torch-{your_torch_version}+{cu}.html

Compute the per-group maximum with scatter_max. Elements of src that share an index value form one group.

import torch
from torch_scatter import scatter_max

src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
index = torch.tensor([0, 0, 1, 1, 1])
out, argmax = scatter_max(src, index, dim=0)
print(out)     # tensor([2., 5.])
print(argmax)  # tensor([1, 4])  (position of each group's max in src)
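What `out` and `argmax` hold can be spelled out with a 1-D plain-Python reference. `scatter_max_ref` is hypothetical and is not the library code. It marks groups that receive no elements with `None`, whereas torch-scatter fills such slots with placeholder values. Its first-position tie-breaking is not guaranteed to match the library's.

```python
from typing import List, Optional, Tuple


def scatter_max_ref(src: List[float], index: List[int],
                    dim_size: Optional[int] = None
                    ) -> Tuple[List[Optional[float]], List[Optional[int]]]:
    """1-D reference: per group, the max value and its position in src."""
    size = dim_size if dim_size is not None else max(index, default=-1) + 1
    out: List[Optional[float]] = [None] * size
    argmax: List[Optional[int]] = [None] * size
    for pos, (value, group) in enumerate(zip(src, index)):
        # Strict '>' keeps the first position when values tie.
        if out[group] is None or value > out[group]:
            out[group] = value
            argmax[group] = pos
    return out, argmax


out, argmax = scatter_max_ref([1.0, 2.0, 3.0, 4.0, 5.0], [0, 0, 1, 1, 1])
print(out)     # [2.0, 5.0]
print(argmax)  # [1, 4]
```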