EMA PyTorch
ema-pytorch is a Python library that provides an easy way to integrate Exponential Moving Average (EMA) into PyTorch models. It helps stabilize training and improve generalization by maintaining a smoothed version of model parameters over time. The library is actively developed, with frequent updates, and is currently at version 0.7.9.
Common errors
- Error: `While trying to deepcopy model: {e} Your model was not copyable. Please make sure you are not using any LazyLinear`
  - cause: The `EMA` wrapper deepcopies the model, and certain PyTorch modules such as `torch.nn.LazyLinear` cannot be deepcopied until their input shape has been determined.
  - fix: Make sure all of your model's layers have concrete input/output shapes before passing it to `EMA`. Avoid `LazyLinear` if you intend to deepcopy, or run one forward pass to materialize it before EMA initialization. The `PostHocEMA` source code specifically mentions `LazyLinear` as an issue.
- KeyError: `'ema_model'` not found when loading checkpoint
  - cause: You are loading only the EMA model's `state_dict` (`ema.ema_model.state_dict()`) into a new `EMA` wrapper, or loading a generic `state_dict` that lacks the keys `ema-pytorch` expects.
  - fix: As recommended, save the entire `ema` object with `torch.save(ema, 'path/to/ema.pth')` and restore it with `torch.load('path/to/ema.pth')` (on PyTorch ≥ 2.6, pass `weights_only=False`, since `torch.load` no longer unpickles arbitrary objects by default). If you must save `state_dict`s, save and reload every relevant component: the model, the EMA copy, and the wrapper's internal state.
- AttributeError: `'EMA' object has no attribute 'train'` (or `'eval'`)
  - cause: `.train()` or `.eval()` is being called on the `ema` wrapper object instead of on the wrapped model or on `ema.ema_model` directly.
  - fix: To put the underlying EMA model in evaluation mode, call `ema.ema_model.eval()`. For training mode (though `ema_model` is typically used only for inference), call `ema.ema_model.train()`, and keep the main `net` in `train()` mode for updates. The `ema` wrapper manages its own state.
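The `LazyLinear` issue above can be avoided with a single dry-run forward pass before wrapping. A minimal pure-PyTorch sketch (the layer sizes and dummy input shape are illustrative):

```python
import copy
import torch

# A model using LazyLinear: its weight shape is unknown until the first forward pass
net = torch.nn.Sequential(
    torch.nn.LazyLinear(128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Materialize the lazy parameters with one dummy forward pass
# (batch size 2 and input dimension 64 are arbitrary placeholders)
with torch.no_grad():
    net(torch.randn(2, 64))

# Every parameter now has a concrete shape, so deepcopy — and therefore
# wrapping the model with EMA — succeeds
ema_copy = copy.deepcopy(net)
print(ema_copy[0].weight.shape)
```

After the dry run, PyTorch swaps the `LazyLinear` instance for a regular `Linear`, so the copy no longer contains uninitialized parameters.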
Warnings
- gotcha When using EMA with optimizers that employ weight decay (e.g., AdamW), there can be interference. EMA tracks raw weights, not their decayed counterparts, which might lead to the EMA not fully accounting for the optimizer's weight decay impact.
- gotcha For Mixed Precision Training (e.g., using `torch.cuda.amp.autocast`, or `torch.amp.autocast` on newer PyTorch versions), perform EMA updates *outside* the `autocast` context to prevent numerical instabilities and precision issues.
- gotcha Extended training can sometimes destabilize EMA decay, especially with decay rates close to 1, potentially leading to 'oversmoothing'.
- gotcha To correctly save and load the EMA state, it's recommended to save the entire EMA wrapper object, not just `ema.ema_model.state_dict()`, as the wrapper contains crucial state like the number of steps taken (for warmup logic).
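The mixed-precision gotcha can be sketched with a minimal hand-rolled EMA update; the `ema_update` helper and the decay value below are illustrative stand-ins, not part of the `ema-pytorch` API:

```python
import torch

net = torch.nn.Linear(16, 16)
ema_net = torch.nn.Linear(16, 16)
ema_net.load_state_dict(net.state_dict())

optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-2)

@torch.no_grad()
def ema_update(ema_model, model, beta=0.999):
    # Lerp the EMA weights toward the online weights in full fp32 precision
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - beta)

for _ in range(10):
    # The forward pass may run under reduced precision...
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = net(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # ...but the EMA update happens OUTSIDE the autocast context, on the
    # fp32 master weights, so no low-precision rounding leaks into the average
    ema_update(ema_net, net)
```

The same pattern applies on GPU with `device_type="cuda"`; the key point is only that `ema.update()` (or its equivalent) sits outside the `autocast` block.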
Install
```shell
pip install ema-pytorch
```
Imports
- EMA: `from ema_pytorch import EMA`
- PostHocEMA: `from ema_pytorch import PostHocEMA`
Quickstart
```python
import torch
from ema_pytorch import EMA

# Your neural network as a PyTorch module
net = torch.nn.Linear(512, 512)

# Wrap your neural network with EMA
ema = EMA(
    net,
    beta = 0.9999,            # exponential moving average factor
    update_after_step = 100,  # only after this number of .update() calls will it start updating
    update_every = 10,        # how often to actually update, to save on compute
)

# Simulate training steps
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3)

for step in range(1000):
    optimizer.zero_grad()
    data = torch.randn(1, 512)
    target = torch.randn(1, 512)
    output = net(data)
    loss = torch.nn.functional.mse_loss(output, target)
    loss.backward()
    optimizer.step()

    # Update the EMA model after each optimizer step
    ema.update()

# Later, for inference, use the EMA model
with torch.no_grad():
    data_inference = torch.randn(1, 512)
    ema_output = ema(data_inference)
    print(f"EMA model output shape: {ema_output.shape}")
```
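Following the checkpointing advice above, a save/load round trip can be sketched in pure PyTorch; a plain `Linear` stands in for the `ema` wrapper here so the example is self-contained, and the file path is illustrative:

```python
import os
import tempfile
import torch

net = torch.nn.Linear(8, 8)

path = os.path.join(tempfile.mkdtemp(), "ema.pth")

# Save the ENTIRE object (for ema-pytorch: torch.save(ema, path)), not just a
# state_dict, so internal bookkeeping such as the warmup step counter survives
torch.save(net, path)

# On PyTorch >= 2.6, torch.load defaults to weights_only=True and refuses to
# unpickle arbitrary objects; pass weights_only=False to restore a full object
restored = torch.load(path, weights_only=False)
```

The restored object is immediately usable for inference; with the real `ema` wrapper you would then call `restored.ema_model.eval()` before evaluation.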