{"id":7797,"library":"torch-ema","title":"PyTorch EMA (Exponential Moving Average)","description":"torch-ema is a compact PyTorch library for efficiently computing and managing exponential moving averages (EMAs) of model parameters during training. EMA weights help stabilize training and often improve generalization. The current version is 0.3.0, last released in November 2021, indicating a slow release cadence.","status":"active","version":"0.3.0","language":"en","source_language":"en","source_url":"https://github.com/fadel/pytorch_ema","tags":["pytorch","ema","deep learning","training utility","exponential moving average"],"install":[{"cmd":"pip install torch-ema","lang":"bash","label":"Install stable version"}],"dependencies":[{"reason":"Core deep learning framework","package":"torch","optional":false}],"imports":[{"symbol":"ExponentialMovingAverage","correct":"from torch_ema import ExponentialMovingAverage"}],"quickstart":{"code":"import torch\nimport torch.nn.functional as F\nfrom torch_ema import ExponentialMovingAverage\n\ntorch.manual_seed(0)\n\nx_train = torch.rand((100, 10))\ny_train = torch.rand(100).round().long()\nx_val = torch.rand((100, 10))\ny_val = torch.rand(100).round().long()\n\nmodel = torch.nn.Linear(10, 2)\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\nema = ExponentialMovingAverage(model.parameters(), decay=0.995)\n\n# Train for a few epochs\nmodel.train()\nfor _ in range(20):\n    logits = model(x_train)\n    loss = F.cross_entropy(logits, y_train)\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n\n    # Update the moving average with the new parameters\n    ema.update()\n\n# Validation: original model\nmodel.eval()\nwith torch.no_grad():\n    logits_orig = model(x_val)\n    loss_orig = F.cross_entropy(logits_orig, y_val)\n    print(f\"Original model validation loss: {loss_orig.item():.4f}\")\n\n# Validation: with EMA\n# The .average_parameters() context manager:\n# (1) saves the original parameters before replacing them with the EMA version\n# (2) copies the EMA parameters to the model\n# (3) on exiting the `with` block, restores the original parameters so training can resume\nwith ema.average_parameters():\n    with torch.no_grad():\n        logits_ema = model(x_val)\n        loss_ema = F.cross_entropy(logits_ema, y_val)\n        print(f\"EMA model validation loss: {loss_ema.item():.4f}\")","lang":"python","description":"Initialize `ExponentialMovingAverage` with your model's parameters and a decay rate. Call `ema.update()` after each optimizer step. For evaluation, use the `ema.average_parameters()` context manager to temporarily swap the model's weights with their EMA counterparts."},"warnings":[{"fix":"If migrating from <0.3.0, review your parameter handling. If you intended to track only trainable parameters, filter the `model.parameters()` iterable passed to EMA yourself. If you want all parameters tracked, v0.3.0+ handles this by default.","message":"In version 0.3.0, the behavior changed to apply EMA to *all* parameters passed to the `ExponentialMovingAverage` object, regardless of whether they have `requires_grad = True`. Prior versions (e.g., v0.2) only applied the moving average to parameters with `requires_grad = True`.","severity":"breaking","affected_versions":"0.3.0+"},{"fix":"After `ema.update()` in each process, you must explicitly synchronize the shadow parameters (`ema.shadow_params`) across all ranks, typically using `torch.distributed.all_reduce()` on each shadow parameter.","message":"When using `torch-ema` in a distributed training setup (e.g., DDP), the EMA parameters (`ema.shadow_params`) are not automatically synchronized across GPUs. This requires manual handling.","severity":"gotcha","affected_versions":"All"},{"fix":"Use `ema.state_dict()` and `ema.load_state_dict()` alongside your model and optimizer state_dicts.","message":"The `ExponentialMovingAverage` object's internal state (shadow parameters and the `num_updates` counter) must be explicitly saved and loaded when checkpointing your model to resume training correctly.","severity":"gotcha","affected_versions":"All"},{"fix":"If EMA for buffers is required, implement custom logic to manage them, or consider PyTorch's built-in `torch.optim.swa_utils.AveragedModel`, which provides options for handling buffers during SWA/EMA.","message":"By default, `torch-ema` manages model *parameters* only. Buffers (e.g., `running_mean` and `running_var` in BatchNorm layers) are not tracked or averaged by `ExponentialMovingAverage`.","severity":"gotcha","affected_versions":"All"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"If you are on v0.3.0+ and only want to track trainable parameters, explicitly filter the parameters passed to `ExponentialMovingAverage`: `ema = ExponentialMovingAverage(filter(lambda p: p.requires_grad, model.parameters()), decay=0.995)`. If you are on an older version and want all parameters tracked, upgrade to v0.3.0+.","cause":"Prior to v0.3.0, `torch-ema` ignored parameters that did not have `requires_grad=True`. If you upgraded to v0.3.0 or later, these parameters are now included, which may change expected behavior.","error":"EMA model does not converge or shows unexpected behavior with non-trainable parameters."},{"fix":"After calling `ema.update()`, iterate over `ema.shadow_params` and apply `torch.distributed.all_reduce(param, op=torch.distributed.ReduceOp.AVG)` to each shadow parameter so that all GPUs hold the same averaged EMA weights.","cause":"EMA parameters are not being synchronized across distributed processes (GPUs); each GPU computes its own independent EMA.","error":"My EMA model performs poorly on multi-GPU (DDP) training, even though the base model trains well."},{"fix":"Always save `ema.state_dict()` and load `ema.load_state_dict(checkpoint['ema_state_dict'])` as part of your checkpointing routine, just as you handle your model and optimizer.","cause":"The `ExponentialMovingAverage` object's state, including its shadow parameters and `num_updates` counter, was not saved or properly restored when resuming from a checkpoint.","error":"After loading a checkpoint, the EMA model's performance is as if it started from scratch, or worse."}]}