Running Statistics for PyTorch

0.2.0 · active · verified Thu Apr 16

torch-runstats provides efficient running (online) statistics for PyTorch tensors: mean, standard deviation, variance, and count. It is designed for cases where data arrives sequentially or is too large to hold in memory at once. The current version is 0.2.0, and releases are infrequent.
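To make the idea of running statistics concrete, here is a minimal sketch of a batched mean and variance update (the pairwise merge usually attributed to Chan et al.). It is written in plain Python so it runs without dependencies. The class `RunningMoments` and its attributes are hypothetical names chosen for illustration; this is not torch-runstats' implementation, and the source does not say which algorithm the library uses.

```python
import math


class RunningMoments:
    """Hypothetical helper for illustration; not part of torch-runstats."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, batch):
        """Merge one batch of floats into the running statistics."""
        n_b = len(batch)
        if n_b == 0:
            return
        # Statistics of the incoming batch alone.
        mean_b = sum(batch) / n_b
        m2_b = sum((x - mean_b) ** 2 for x in batch)

        # Merge the batch with everything seen so far.
        n_a = self.count
        n = n_a + n_b
        delta = mean_b - self.mean
        self.mean += delta * n_b / n
        self.m2 += m2_b + delta * delta * n_a * n_b / n
        self.count = n

    @property
    def variance(self):
        # Population variance (divides by count, not count - 1).
        return self.m2 / self.count if self.count else float("nan")

    @property
    def std(self):
        return math.sqrt(self.variance)


stats = RunningMoments()
stats.update([1.0, 2.0, 3.0])
stats.update([4.0, 5.0])
print(stats.count, stats.mean, stats.variance)
```

Only three numbers are stored per tracked quantity, so memory use does not grow with the amount of data seen. For a tensor with several features, the same update would be applied independently to each feature.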

Quickstart

This quickstart initializes `RunningMeanStd` and `RunningStats` and uses them to track statistics over streaming data. The `shape` parameter sets the shape of each tracked sample, so multi-dimensional data is handled per feature. The second `RunningStats` batch contains a NaN entry, which exercises NaN masking (`mask_nan=True`, the default in v0.2.0).

```python
import torch
from torch_runstats import RunningMeanStd, RunningStats

# Example with RunningMeanStd
# Initialize for a feature vector of size 3
rms = RunningMeanStd(shape=(3,))

# Simulate incoming data
x1 = torch.randn(10, 3)
x2 = torch.randn(5, 3)

rms.update(x1)
rms.update(x2)

print(f"Running Mean: {rms.mean}")
print(f"Running Std Dev: {rms.std}")

# Example with RunningStats (more general, includes variance and count)
rs = RunningStats(shape=(2,))
y1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y2 = torch.tensor([[5.0, 6.0], [float('nan'), 8.0]])  # Demonstrating NaN masking

rs.update(y1)
rs.update(y2)

print(f"Running Stats Mean: {rs.mean}")
print(f"Running Stats Std Dev: {rs.std}")
print(f"Running Stats Count: {rs.count}")  # NaN in y2 is ignored by default
```
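As a cross-check on what NaN masking would mean for the `y1` and `y2` data above, the following plain-Python sketch computes per-column counts and means while skipping NaN entries. It assumes masking is applied per element rather than by dropping whole rows, which the text above does not confirm. `masked_column_stats` is a hypothetical helper, not a torch-runstats function.

```python
import math


def masked_column_stats(rows):
    """Return per-column (counts, means) over `rows`, skipping NaN entries."""
    n_cols = len(rows[0])
    counts = [0] * n_cols
    sums = [0.0] * n_cols
    for row in rows:
        for j, value in enumerate(row):
            if math.isnan(value):
                continue  # masked: contributes to neither count nor sum
            counts[j] += 1
            sums[j] += value
    means = [s / c if c else float("nan") for s, c in zip(sums, counts)]
    return counts, means


# The same values as y1 followed by y2 in the quickstart.
rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [float("nan"), 8.0]]
counts, means = masked_column_stats(rows)
print(counts)  # [3, 4]
print(means)   # [3.0, 5.0]
```

Under this assumption the first column is averaged over three values and the second over four, so a per-feature count differs between columns once a NaN has been seen.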
