PyTorch EMA (Exponential Moving Average)
torch-ema is a compact PyTorch library for efficiently computing and managing exponential moving averages (EMAs) of model parameters during training. Evaluating with EMA weights helps stabilize training and often improves generalization. The current version is 0.3.0, released in November 2021, indicating a slow release cadence.
Common errors
-
EMA model does not converge or shows unexpected behavior with non-trainable parameters.
cause: Prior to v0.3.0, `torch-ema` would ignore parameters that did not have `requires_grad=True`. As of v0.3.0, these parameters are included, which may change expected behavior.
fix: If you are on v0.3.0+ and only want to track trainable parameters, explicitly filter the parameters passed to `ExponentialMovingAverage`: `ema = ExponentialMovingAverage(filter(lambda p: p.requires_grad, model.parameters()), decay=0.995)`. If you are on an older version and want all parameters tracked, upgrade to v0.3.0+.
-
My EMA model performs poorly on multi-GPU (DDP) training, even though the base model trains well.
cause: EMA parameters are not synchronized across distributed processes (GPUs); each GPU computes its own independent EMA.
fix: After calling `ema.update()`, all-reduce each tensor in `ema.shadow_params` with `torch.distributed.all_reduce(param, op=torch.distributed.ReduceOp.AVG)` so that every rank holds the same averaged EMA weights.
-
After loading a checkpoint, the EMA model's performance is as if it started from scratch, or worse.
cause: The `ExponentialMovingAverage` object's state, including its shadow parameters and update count, was not saved or properly loaded when resuming from a checkpoint.
fix: Always save `ema.state_dict()` and restore it with `ema.load_state_dict(checkpoint['ema_state_dict'])` as part of your checkpointing routine, just as you do for the model and optimizer.
Warnings
- breaking In version 0.3.0, the behavior changed to apply EMA to *all* parameters passed to the `ExponentialMovingAverage` object, regardless of whether they have `requires_grad = True`. Prior versions (e.g., v0.2) would partially ignore parameters without `requires_grad = True`.
- gotcha When using `torch-ema` in a distributed training setup (e.g., DDP), the EMA parameters (`ema.shadow`) are not automatically synchronized across GPUs. This requires manual handling.
- gotcha The `ExponentialMovingAverage` object's internal state (shadow parameters and update count) must be explicitly saved and loaded when checkpointing your model to resume training correctly.
- gotcha By default, `torch-ema` primarily manages model *parameters*. Buffers (e.g., `running_mean`, `running_var` in BatchNorm layers) are not automatically tracked or averaged by `ExponentialMovingAverage`.
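If you need averaged buffers as well, one option is to maintain the average yourself. The helper below is a hypothetical sketch (`update_buffer_ema` is not part of `torch-ema`), assuming floating-point buffers should be EMA-averaged and integer buffers simply copied:

```python
import torch

@torch.no_grad()
def update_buffer_ema(model, shadow_buffers, decay=0.995):
    """Hypothetical helper: keep an EMA of a model's buffers
    (e.g., BatchNorm running stats), which torch-ema does not track."""
    for name, buf in model.named_buffers():
        if name not in shadow_buffers:
            shadow_buffers[name] = buf.detach().clone()
        elif buf.dtype.is_floating_point:
            shadow_buffers[name].mul_(decay).add_(buf, alpha=1.0 - decay)
        else:
            # Integer buffers (e.g., num_batches_tracked) are copied, not averaged.
            shadow_buffers[name].copy_(buf)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
shadow_buffers = {}
model(torch.randn(8, 4))               # forward pass updates BatchNorm running stats
update_buffer_ema(model, shadow_buffers)  # call once per training step, after ema.update()
```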
Install
-
pip install torch-ema
Imports
- ExponentialMovingAverage
from torch_ema import ExponentialMovingAverage
Quickstart
import torch
import torch.nn.functional as F
from torch_ema import ExponentialMovingAverage
torch.manual_seed(0)
x_train = torch.rand((100, 10))
y_train = torch.rand(100).round().long()
x_val = torch.rand((100, 10))
y_val = torch.rand(100).round().long()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
# Train for a few epochs
model.train()
for _ in range(20):
    logits = model(x_train)
    loss = F.cross_entropy(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Update the moving average with the new parameters
    ema.update()
# Validation: original model
model.eval()
with torch.no_grad():
    logits_orig = model(x_val)
    loss_orig = F.cross_entropy(logits_orig, y_val)
print(f"Original model validation loss: {loss_orig.item():.4f}")
# Validation: with EMA
# The .average_parameters() context manager:
# (1) saves original parameters before replacing with EMA version
# (2) copies EMA parameters to model
# (3) after exiting the `with`, restores original parameters to resume training later
with ema.average_parameters():
    with torch.no_grad():
        logits_ema = model(x_val)
        loss_ema = F.cross_entropy(logits_ema, y_val)
print(f"EMA model validation loss: {loss_ema.item():.4f}")
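In distributed training (see the DDP gotcha under Warnings), the shadow parameters can be synchronized after each update. A sketch, assuming `ema.shadow_params` is the list of shadow tensors and that `torch.distributed` has been initialized by your launcher; `sync_ema` is a hypothetical helper, not part of `torch-ema`:

```python
import torch.distributed as dist

def sync_ema(ema):
    # Hypothetical helper: average EMA shadow parameters across DDP ranks.
    # No-op when not running under torch.distributed.
    if not (dist.is_available() and dist.is_initialized()):
        return
    world_size = dist.get_world_size()
    for shadow in ema.shadow_params:
        # SUM then divide is portable across backends; ReduceOp.AVG is NCCL-only.
        dist.all_reduce(shadow, op=dist.ReduceOp.SUM)
        shadow.div_(world_size)

# In the training loop: call ema.update(), then sync_ema(ema).
```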