EMA PyTorch

0.7.9 · active · verified Thu Apr 16

ema-pytorch is a Python library that provides an easy way to maintain an Exponential Moving Average (EMA) of a PyTorch model's parameters. Keeping a smoothed copy of the weights over time helps stabilize training and often improves generalization at inference. The library is actively developed and is currently at version 0.7.9.
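Conceptually, EMA keeps a shadow copy of each parameter and blends in the current online value at every update. A minimal plain-Python sketch of the update rule (illustrative only, not the library's actual implementation):

```python
# Illustrative EMA update rule (not ema-pytorch's internal code):
#   ema_param <- beta * ema_param + (1 - beta) * online_param
def ema_update(ema_params, online_params, beta=0.9999):
    return [beta * e + (1 - beta) * p for e, p in zip(ema_params, online_params)]

# With a small beta the average converges quickly toward the online values;
# beta=0.9999 (the library default) moves far more slowly.
ema = [0.0, 0.0]
online = [1.0, 2.0]
for _ in range(5):
    ema = ema_update(ema, online, beta=0.5)
print(ema)  # approaches [1.0, 2.0] geometrically
```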

Install

pip install ema-pytorch

Imports

from ema_pytorch import EMA

Quickstart

Initialize your PyTorch model, then wrap it with the `EMA` class, specifying the decay factor (`beta`). During your training loop, call `ema.update()` after `optimizer.step()` to update the EMA parameters. For inference or validation, you can directly call the `ema` object, which will use the averaged parameters.

import torch
from ema_pytorch import EMA

# Your neural network as a PyTorch module
net = torch.nn.Linear(512, 512)

# Wrap your neural network with EMA
ema = EMA(
    net,
    beta = 0.9999,  # exponential moving average decay factor
    update_after_step = 100,  # start averaging only after this many .update() calls
    update_every = 10  # perform the update every 10 calls, to save compute
)

# Simulate training steps
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
for step in range(1000):
    optimizer.zero_grad()
    data = torch.randn(1, 512)
    target = torch.randn(1, 512)
    output = net(data)
    loss = torch.nn.functional.mse_loss(output, target)
    loss.backward()
    optimizer.step()
    
    # Update the EMA model
    ema.update()

# Later, for inference, use the EMA model
with torch.no_grad():
    data_inference = torch.randn(1, 512)
    ema_output = ema(data_inference)
    print(f"EMA model output shape: {ema_output.shape}")
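The `update_after_step` and `update_every` knobs simply gate which calls to `ema.update()` actually move the average. A plain-Python sketch of that gating, using a single scalar in place of the parameter tensors; the warmup behavior (copying the online value until `update_after_step` is reached) mirrors the documented intent, but treat the exact details as an assumption rather than the library's code:

```python
# Illustrative sketch of update_after_step / update_every gating
# (not ema-pytorch's actual implementation).
class GatedEMA:
    def __init__(self, value, beta=0.9999, update_after_step=100, update_every=10):
        self.value = value                    # averaged scalar standing in for parameters
        self.beta = beta
        self.update_after_step = update_after_step
        self.update_every = update_every
        self.step = 0

    def update(self, online_value):
        self.step += 1
        if self.step <= self.update_after_step:
            # warmup: track the online value directly instead of averaging
            self.value = online_value
            return
        if self.step % self.update_every != 0:
            return                            # skipped call, saves compute
        self.value = self.beta * self.value + (1 - self.beta) * online_value

# Small numbers so the gating is easy to trace by hand.
ema = GatedEMA(0.0, beta=0.9, update_after_step=2, update_every=2)
for v in [1.0, 1.0, 2.0, 2.0]:
    ema.update(v)
print(ema.value)  # steps 1-2 copy, step 3 skips, step 4 averages: 0.9*1.0 + 0.1*2.0
```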
