PyTorch EMA (Exponential Moving Average)

0.3.0 · active · verified Thu Apr 16

torch-ema is a compact PyTorch library for computing and managing exponential moving averages (EMA) of model parameters during training. The averaged weights are a smoothed version of the training trajectory: evaluating with them can give more stable results than the latest raw weights and often improves generalization. The current version is 0.3.0. The last release was in November 2021, which suggests a slow release cadence.
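As a rough illustration of what an exponential moving average of parameters computes, the pure-Python sketch below applies the update shadow = decay * shadow + (1 - decay) * param to a list of floats. The helper names `ema_decay` and `ema_step` are hypothetical. `ema_decay` assumes, based on our reading of the torch-ema source (check it against your installed version), that the library's default `use_num_updates=True` caps the decay at (1 + n) / (10 + n) after n updates as a warm-up.

```python
from typing import List, Optional


def ema_decay(decay: float, num_updates: Optional[int]) -> float:
    """Decay used for one update.

    Assumption: mirrors torch-ema's warm-up, min(decay, (1 + n) / (10 + n)),
    where n is the number of updates so far; pass None to use the raw decay.
    """
    if num_updates is None:
        return decay
    return min(decay, (1 + num_updates) / (10 + num_updates))


def ema_step(shadow: List[float], params: List[float], decay: float) -> List[float]:
    """One EMA update: shadow <- decay * shadow + (1 - decay) * param."""
    # Written as s - (1 - decay) * (s - p), which is algebraically the same.
    return [s - (1.0 - decay) * (s - p) for s, p in zip(shadow, params)]


# The shadow starts as a copy of the weights; the weights then jump to 1.0.
shadow = [0.0]
weights = [1.0]
history = []
for _ in range(3):
    shadow = ema_step(shadow, weights, decay=0.5)
    history.append(shadow[0])

print(history)                  # [0.5, 0.75, 0.875]
print(ema_decay(0.995, 20))     # 0.7 (warm-up still active after 20 updates)
```

A common rule of thumb is that the average spans roughly 1 / (1 - decay) updates, about 200 for decay=0.995. If the warm-up assumption holds, the 20-step quickstart only reaches an effective decay of 21/30 = 0.7, so its EMA weights stay fairly close to the latest weights.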

Install

pip install torch-ema

Imports

from torch_ema import ExponentialMovingAverage

Quickstart

Initialize `ExponentialMovingAverage` with your model's parameters and a decay rate. Call `ema.update()` after each optimizer step. For evaluation, use the `ema.average_parameters()` context manager to temporarily swap model weights with their EMA counterparts.

import torch
import torch.nn.functional as F
from torch_ema import ExponentialMovingAverage

torch.manual_seed(0)

x_train = torch.rand((100, 10))
y_train = torch.rand(100).round().long()
x_val = torch.rand((100, 10))
y_val = torch.rand(100).round().long()

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# Train for a few epochs
model.train()
for _ in range(20):
    logits = model(x_train)
    loss = F.cross_entropy(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Update the moving average with the new parameters
    ema.update()

# Validation: original model
model.eval()
with torch.no_grad():
    logits_orig = model(x_val)
    loss_orig = F.cross_entropy(logits_orig, y_val)
    print(f"Original model validation loss: {loss_orig.item():.4f}")

# Validation: with EMA
# The .average_parameters() context manager:
# (1) saves original parameters before replacing with EMA version
# (2) copies EMA parameters to model
# (3) after exiting the `with`, restores original parameters to resume training later
with ema.average_parameters():
    with torch.no_grad():
        logits_ema = model(x_val)
        loss_ema = F.cross_entropy(logits_ema, y_val)
        print(f"EMA model validation loss: {loss_ema.item():.4f}")
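The three commented steps of `average_parameters()` amount to save, swap, and restore. The sketch below reproduces that pattern in plain Python, with a dict standing in for the model's weights. `ToyEMA` and its methods are hypothetical and written for illustration only; they are not the torch-ema API.

```python
from contextlib import contextmanager
from typing import Dict, Iterator


class ToyEMA:
    """Hypothetical stand-in for the swap-and-restore idea; not torch-ema itself."""

    def __init__(self, params: Dict[str, float], decay: float) -> None:
        self.params = params              # live "model" weights, mutated in place
        self.decay = decay
        self.shadow = dict(params)        # EMA copy, initialised from the weights

    def update(self) -> None:
        # shadow <- decay * shadow + (1 - decay) * param, per named weight
        for name, value in self.params.items():
            self.shadow[name] -= (1.0 - self.decay) * (self.shadow[name] - value)

    @contextmanager
    def average_parameters(self) -> Iterator[None]:
        backup = dict(self.params)        # (1) save the training weights
        self.params.update(self.shadow)   # (2) copy EMA weights into the model
        try:
            yield
        finally:
            self.params.update(backup)    # (3) restore the training weights


weights = {"w": 0.0}
ema = ToyEMA(weights, decay=0.5)
weights["w"] = 1.0          # pretend an optimizer step moved the weight
ema.update()                # shadow["w"] becomes 0.5

with ema.average_parameters():
    inside = weights["w"]   # evaluation sees the EMA value
after = weights["w"]        # training value is back

print(inside, after)        # 0.5 1.0
```

Putting the restore in `finally` is a deliberate choice in this sketch: if evaluation inside the `with` block raises, the training weights are still put back, so training can resume from them.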
