Running Statistics for PyTorch
torch-runstats provides efficient running (online) statistics, such as mean, standard deviation, variance and count, for PyTorch tensors. It is designed for cases where data arrives sequentially or is too large to store in full. The current version is 0.2.0. Releases are infrequent, which suggests a mature, stable library within its narrow scope.
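A minimal sketch of the underlying idea, written in plain Python rather than against torch-runstats. The class `OnlineMoments` and its methods are hypothetical. They only illustrate how a running mean and variance can be updated batch by batch without keeping past data. The sketch assumes rows of equal length and merges each batch with Chan et al.'s pairwise formula for combining counts, means and sums of squared deviations (M2). The sample (n - 1) variance convention is an assumption; the library's convention is not stated here.

```python
import math
from typing import List, Sequence


class OnlineMoments:
    """Hypothetical per-feature running mean/variance (not the torch-runstats API).

    Only count, mean and M2 (sum of squared deviations) are stored, so memory
    stays constant no matter how many batches are seen.
    """

    def __init__(self, n_features: int) -> None:
        self.n_features = n_features
        self.count = 0
        self.mean: List[float] = [0.0] * n_features
        self.m2: List[float] = [0.0] * n_features

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        """Merge a batch of rows (each of length n_features) into the tally."""
        n_b = len(batch)
        if n_b == 0:
            return
        total = self.count + n_b
        for j in range(self.n_features):
            column = [row[j] for row in batch]
            mean_b = sum(column) / n_b
            m2_b = sum((v - mean_b) ** 2 for v in column)
            # Chan et al. pairwise merge of (count, mean, M2) for the running
            # tally and the new batch; earlier data is never revisited.
            delta = mean_b - self.mean[j]
            self.mean[j] += delta * n_b / total
            self.m2[j] += m2_b + delta ** 2 * self.count * n_b / total
        self.count = total

    def var(self) -> List[float]:
        """Sample variance (n - 1 denominator); NaN until two samples are seen."""
        if self.count < 2:
            return [math.nan] * self.n_features
        return [m / (self.count - 1) for m in self.m2]

    def std(self) -> List[float]:
        return [math.sqrt(v) for v in self.var()]


stats = OnlineMoments(n_features=2)
stats.update([[1.0, 2.0], [3.0, 4.0]])
stats.update([[5.0, 6.0], [7.0, 8.0]])
print(stats.count, stats.mean)  # 4 [4.0, 5.0]
```

Because each update touches only the stored tallies and the new batch, the same rule applies elementwise to tensors, which is the scenario torch-runstats targets.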
Common errors
- RuntimeError: The size of tensor a (10) must match the size of tensor b (3) at non-singleton dimension 1
  cause: The `shape` passed to `RunningStats` or `RunningMeanStd` at initialization does not match the trailing dimensions of the input tensor passed to `update()`.
  fix: Initialize the `RunningStats` or `RunningMeanStd` object with a `shape` that matches the last dimension(s) of your input data. For a tensor `x` of shape `(batch, ..., feature_dim)`, use `RunningStats(shape=(feature_dim,))`.
- UserWarning: std is NaN due to insufficient data.
  cause: `.std` was read before `update()` had accumulated at least two data points.
  fix: Make sure the `RunningStats` or `RunningMeanStd` instance has received at least two valid data points through `update()` before reading `.std`. While `count < 2`, the standard deviation is undefined and is reported as NaN or zero.
- AttributeError: 'RunningStats' object has no attribute 'some_method_from_torch_scatter'
  cause: Your code relies on functionality that was available only because `torch_scatter` was installed as a dependency of `torch-runstats` before version 0.2.0.
  fix: If other parts of your code need `torch_scatter`, install it explicitly (`pip install torch-scatter`). `torch-runstats` no longer depends on it as of v0.2.0.
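The first error comes down to comparing the trailing dimensions of the input with the `shape` the statistics object was built with. A small pre-flight check such as the hypothetical `trailing_shape_matches` below (not part of torch-runstats) can catch the mismatch before `update()` is called. `torch.Size` is a tuple subclass, so `x.shape` can be passed in directly. Reading the error message as an input of shape `(batch, 10)` fed to an object built with `shape=(3,)` is an assumption.

```python
from typing import Sequence


def trailing_shape_matches(input_shape: Sequence[int], stats_shape: Sequence[int]) -> bool:
    """Hypothetical helper: True if the input ends with exactly stats_shape.

    Leading dimensions (batch and any others) are ignored, matching the
    "trailing dimensions" rule described above.
    """
    in_shape = tuple(input_shape)
    st_shape = tuple(stats_shape)
    k = len(st_shape)
    if k > len(in_shape):
        return False
    return in_shape[len(in_shape) - k:] == st_shape


# An input of shape (batch, 10) against an object built with shape=(3,)
# would be consistent with the RuntimeError above.
print(trailing_shape_matches((32, 10), (3,)))    # False
print(trailing_shape_matches((32, 10), (10,)))   # True
print(trailing_shape_matches((8, 4, 10), (10,)))  # True
```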
Warnings
- gotcha The `shape` parameter in `RunningMeanStd` and `RunningStats` initialization is crucial. It defines the shape of the *feature vector* for which statistics are computed, not the batch dimension. Incorrect `shape` leads to dimension mismatch errors during `update`.
- gotcha Since version 0.2.0, `RunningStats` defaults to `mask_nan=True`: `NaN` values in the input tensor are excluded from both the statistics and the count. Code upgraded from v0.1.0, or code that expects `NaN` to propagate, may see different results.
- gotcha Standard deviation (`std`) and variance are undefined until `RunningStats` or `RunningMeanStd` has accumulated at least two data points. Reading `std` earlier returns `NaN` or `0`.
- breaking Version 0.2.0 removed the dependency on `torch_scatter`. This mainly affects the internal implementation and reduces install size, but code that relied on `torch_scatter` being installed as a side effect of installing `torch-runstats` may now fail to find it.
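The NaN-masking and minimum-count gotchas can be illustrated with a hypothetical `MaskedRunningStats` class in plain Python; it is not the library's implementation. It keeps a separate count per feature, skips NaNs when `mask_nan=True` (a flag name borrowed from the description above), and lets them propagate otherwise. It uses Welford's single-value update and reports the sample standard deviation as NaN until a feature has at least two values. The library's exact conventions may differ.

```python
import math
from typing import List, Sequence


class MaskedRunningStats:
    """Hypothetical NaN-aware running mean/std (illustration only)."""

    def __init__(self, n_features: int, mask_nan: bool = True) -> None:
        self.mask_nan = mask_nan
        self.count: List[int] = [0] * n_features  # one count per feature
        self.mean: List[float] = [0.0] * n_features
        self.m2: List[float] = [0.0] * n_features

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        for row in batch:
            for j, v in enumerate(row):
                if self.mask_nan and math.isnan(v):
                    continue  # masked: affects neither this feature's mean nor its count
                # Welford's update; an unmasked NaN turns mean and M2 into NaN.
                self.count[j] += 1
                delta = v - self.mean[j]
                self.mean[j] += delta / self.count[j]
                self.m2[j] += delta * (v - self.mean[j])

    def std(self) -> List[float]:
        # Sample std (n - 1); undefined, reported as NaN, below two values.
        return [
            math.sqrt(m2 / (n - 1)) if n >= 2 else math.nan
            for n, m2 in zip(self.count, self.m2)
        ]


data = [[1.0, 10.0], [math.nan, 20.0], [3.0, 30.0]]

masked = MaskedRunningStats(2)  # mask_nan=True
masked.update(data)
print(masked.count, masked.mean)  # [2, 3] [2.0, 20.0]

propagating = MaskedRunningStats(2, mask_nan=False)
propagating.update(data)
print(propagating.mean)  # [nan, 20.0]

print(MaskedRunningStats(2).std())  # [nan, nan]
```

Keeping a count per feature is what lets a NaN in one column leave the other columns' statistics untouched.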
Install
pip install torch-runstats
Imports
- RunningMeanStd
from torch_runstats import RunningMeanStd
- RunningStats
from torch_runstats import RunningStats
Quickstart
import torch
from torch_runstats import RunningMeanStd, RunningStats
# Example with RunningMeanStd
# Initialize for a feature vector of size 3
rms = RunningMeanStd(shape=(3,))
# Simulate incoming data
x1 = torch.randn(10, 3)
x2 = torch.randn(5, 3)
rms.update(x1)
rms.update(x2)
print(f"Running Mean: {rms.mean}")
print(f"Running Std Dev: {rms.std}")
# Example with RunningStats (more general, includes variance and count)
rs = RunningStats(shape=(2,))
y1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y2 = torch.tensor([[5.0, 6.0], [float('nan'), 8.0]]) # Demonstrating NaN masking
rs.update(y1)
rs.update(y2)
print(f"Running Stats Mean: {rs.mean}")
print(f"Running Stats Std Dev: {rs.std}")
print(f"Running Stats Count: {rs.count}") # NaN in y2 is ignored by default
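To sanity-check the `RunningStats` example, the same numbers can be computed in one pass over all rows with Python's `statistics` module. This assumes NaN masking is applied per feature, so the NaN in `y2` removes only that row's first-column value. Neither the standard-deviation denominator nor the shape of `count` is specified above, so both the sample and population standard deviations are computed and counts are shown per feature.

```python
import math
import statistics

# The quickstart's y1 and y2 rows as plain Python lists.
rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [math.nan, 8.0]]
n_features = 2

# Drop NaNs per column, mirroring the masking behavior described under Warnings.
columns = [[r[j] for r in rows if not math.isnan(r[j])] for j in range(n_features)]

counts = [len(c) for c in columns]
means = [statistics.mean(c) for c in columns]
sample_std = [statistics.stdev(c) for c in columns]       # n - 1 denominator
population_std = [statistics.pstdev(c) for c in columns]  # n denominator

print(counts, means)  # [3, 4] [3.0, 5.0]
print(sample_std)
print(population_std)
```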