{"id":7204,"library":"ema-pytorch","title":"EMA PyTorch","description":"ema-pytorch is a Python library that provides an easy way to integrate Exponential Moving Average (EMA) into PyTorch models. It helps stabilize training and improve generalization by maintaining a smoothed version of model parameters over time. The library is actively developed, with frequent updates, and is currently at version 0.7.9.","status":"active","version":"0.7.9","language":"en","source_language":"en","source_url":"https://github.com/lucidrains/ema-pytorch","tags":["pytorch","deep learning","exponential moving average","machine learning","training","computer vision","nlp"],"install":[{"cmd":"pip install ema-pytorch","lang":"bash","label":"Install stable version"}],"dependencies":[{"reason":"Core deep learning framework for model operations.","package":"torch","optional":false}],"imports":[{"note":"Primary class for standard Exponential Moving Average.","symbol":"EMA","correct":"from ema_pytorch import EMA"},{"note":"Class for post-hoc EMA synthesis, proposed by Karras et al.","symbol":"PostHocEMA","correct":"from ema_pytorch import PostHocEMA"}],"quickstart":{"code":"import torch\nfrom ema_pytorch import EMA\n\n# Your neural network as a PyTorch module\nnet = torch.nn.Linear(512, 512)\n\n# Wrap your neural network with EMA\nema = EMA(\n    net,\n    beta = 0.9999,  # exponential moving average factor\n    update_after_step = 100, # only after this number of .update() calls will it start updating\n    update_every = 10  # how often to actually update, to save on compute\n)\n\n# Simulate training steps\noptimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\nfor step in range(1000):\n    optimizer.zero_grad()\n    data = torch.randn(1, 512)\n    target = torch.randn(1, 512)\n    output = net(data)\n    loss = torch.nn.functional.mse_loss(output, target)\n    loss.backward()\n    optimizer.step()\n    \n    # Update the EMA model\n    ema.update()\n\n# Later, for inference, use the EMA 
model\nwith torch.no_grad():\n    data_inference = torch.randn(1, 512)\n    ema_output = ema(data_inference)\n    print(f\"EMA model output shape: {ema_output.shape}\")","lang":"python","description":"Initialize your PyTorch model, then wrap it with the `EMA` class, specifying the decay factor (`beta`). During your training loop, call `ema.update()` after `optimizer.step()` to update the EMA parameters. For inference or validation, call the `ema` object directly; it forwards through the averaged parameters."},"warnings":[{"fix":"Consider using a lower weight decay rate when EMA is active, or apply EMA updates only to the optimizer's non-decayed weights through a custom schedule.","message":"When using EMA with optimizers that employ weight decay (e.g., AdamW), there can be interference: EMA averages the raw post-step weights, so it does not distinguish the optimizer's weight decay contribution from the gradient update itself.","severity":"gotcha","affected_versions":"All"},{"fix":"Ensure `ema.update()` is called after `scaler.update()` and outside the `with autocast():` block.","message":"For mixed precision training (e.g., using `torch.cuda.amp.autocast`), EMA updates should occur *outside* the `autocast` context to prevent numerical instabilities and precision issues.","severity":"gotcha","affected_versions":"All"},{"fix":"Use a decay-rate schedule that starts lower and increases as the model converges (e.g., ramping `beta` from 0.99 to 0.999 over epochs), rather than holding a rate close to 1 from the start.","message":"Extended training can sometimes destabilize EMA, especially with decay rates close to 1, potentially leading to oversmoothing.","severity":"gotcha","affected_versions":"All"},{"fix":"Save the wrapper's full state with `torch.save(ema.state_dict(), 'ema.pth')` and restore it with `ema.load_state_dict(torch.load('ema.pth'))`. Because the `EMA` wrapper is an `nn.Module`, its `state_dict()` carries the internal buffers (such as the update-step counter) that drive the warmup logic. If you instead save the whole object with `torch.save(ema, 'ema.pth')`, note that recent PyTorch versions require `torch.load('ema.pth', weights_only=False)` to unpickle it.","message":"To correctly save and load the EMA state, persist the state of the EMA wrapper itself, not just `ema.ema_model.state_dict()`; the wrapper holds crucial state such as the number of update steps taken (used by the warmup logic).","severity":"gotcha","affected_versions":"All"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Ensure that your model's layers are fully initialized with concrete input/output shapes before passing the model to `EMA`. Avoid `LazyLinear` if you intend to deepcopy, or run one forward pass to materialize its shapes before EMA initialization. The `PostHocEMA` source code specifically calls out `LazyLinear` as an issue.","cause":"The EMA wrapper attempts to deepcopy the model, and certain PyTorch modules like `torch.nn.LazyLinear` cannot be deepcopied until their input shape is determined.","error":"Error: While trying to deepcopy model: {e} Your model was not copyable. Please make sure you are not using any LazyLinear"},{"fix":"Save and restore the wrapper's own state: `torch.save(ema.state_dict(), 'path/to/ema.pth')` followed by `ema.load_state_dict(torch.load('path/to/ema.pth'))`, so the keys the wrapper expects are present. Alternatively, save the entire `ema` object with `torch.save(ema, 'path/to/ema.pth')` and load it with `torch.load('path/to/ema.pth', weights_only=False)` on recent PyTorch.","cause":"You might be trying to load only the `state_dict` of the EMA model (`ema.ema_model.state_dict()`) into a new `EMA` wrapper, or loading a general `state_dict` that lacks the specific keys `ema-pytorch` expects.","error":"KeyError: 'ema_model' not found when loading checkpoint"},{"fix":"To put the averaged copy in evaluation mode, call `ema.ema_model.eval()`; for updates, keep the main `net` in `train()` mode (the `ema_model` is typically used only for inference). The `ema` wrapper manages its own internal state.","cause":"You are calling `.train()` or `.eval()` directly on the `ema` wrapper object; these methods should be called on the *wrapped* model, or on `ema.ema_model` directly.","error":"AttributeError: 'EMA' object has no attribute 'train' or 'eval'"}]}