HeavyBall: Compile-first PyTorch optimizer library
HeavyBall is a PyTorch optimizer library that emphasizes 'compile-first' design, assembling optimizers from composable, compiled building blocks. It provides API-compatible replacements for `torch.optim` optimizers like AdamW, SGD, and RMSprop, along with over 30 specialized optimizers such as Muon, SOAP/Shampoo, PSGD, and Schedule-Free. Currently at version 3.0.0, the library is actively maintained with a focus on `torch.compile` fusion, Triton kernel optimization, and memory efficiency, including features like ECC state compression.
Common errors
- AttributeError: module 'heavyball' has no attribute 'ForeachAdamW'
  - cause: Using an optimizer name with the `Foreach*` prefix (e.g., `ForeachAdamW`) after upgrading to HeavyBall v3.0.0 or later, where these prefixes were removed.
  - fix: Update your imports and optimizer instantiations to the simplified, shorter class names. For `ForeachAdamW`, use `from heavyball import AdamW`.
- RuntimeError: Error(s) in loading state_dict for SimpleModel: Unexpected key(s) in state_dict: "optimizer_states.0.state.step".
  - cause: Loading a model or optimizer checkpoint saved with an older version of HeavyBall (v1.x or v2.x) into a newer version (v2.0.0+ or v3.0.0+) without the necessary migration steps; the internal state representation changed between major versions.
  - fix: For checkpoints saved with HeavyBall v1.x, use the `scripts/migrate_optimizer_state.py` utility provided in the repository. For v2.x checkpoints, consult the v3.0.0 migration guide for any required conversion steps.
- Optimizer step produces significantly different or worse convergence compared to `torch.optim` or Optax.
  - cause: HeavyBall's default division backend (`eps_clamp`), used when computing adaptive learning rates and update norms, differs from the `eps_add` method used by `torch.optim` and Optax, leading to numerical discrepancies.
  - fix: To match the standard behavior, set the division backend globally before initializing optimizers: `import heavyball.utils; heavyball.utils.default_division_backend = "eps_add"`.
- ModuleNotFoundError: No module named 'heavyball.optimizers'
  - cause: Incorrect import path. HeavyBall optimizers are available directly under the `heavyball` namespace, not in a nested `heavyball.optimizers` module.
  - fix: Change `from heavyball.optimizers import AdamW` to `from heavyball import AdamW`.
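To keep one codebase working across the v3.0.0 rename, a small import shim can try the new name first and fall back to the old one. This is a minimal sketch, not an official HeavyBall utility; the final fallback to `torch.optim.AdamW` leans on the library's claim of API compatibility and is only for environments where HeavyBall is absent:

```python
# Compatibility shim for the v3.0.0 rename (ForeachAdamW -> AdamW).
try:
    from heavyball import AdamW  # HeavyBall v3.0.0+
except ImportError:
    try:
        from heavyball import ForeachAdamW as AdamW  # pre-v3.0.0 name
    except ImportError:
        from torch.optim import AdamW  # HeavyBall not installed at all
```

After the shim, the rest of the code can use `AdamW` unconditionally, regardless of the installed HeavyBall version.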
Warnings
- breaking HeavyBall v3.0.0 removed `Foreach*` prefixes from optimizer class names (e.g., `ForeachAdamW` is now `AdamW`). Code relying on the old naming convention will break.
- breaking HeavyBall v2.2.0 introduced changes to the SOAP optimizer infrastructure. Custom SOAP variants created for earlier versions may not work out-of-the-box.
- gotcha HeavyBall's default division backend (`eps_clamp`) differs from the industry standard (`eps_add`) used by PyTorch and Optax, potentially leading to meaningfully different numerical results if not accounted for.
- gotcha When using ECC (Error Correction Code) with `torch.compile`, earlier versions (pre-v2.3.1) could experience `torch.compile` fusing away crucial ECC math, leading to incorrect results, particularly with stochastic rounding.
- breaking HeavyBall v2.0.0 reworked the chainable backend and improved numerical stability and SVD computation accuracy; these changes altered the checkpoint format. Optimizer checkpoints saved with HeavyBall v1.x are not directly compatible.
- gotcha HeavyBall optimizers, by default, consume gradients during `step()` and clear `p.grad`. If your training loop requires gradients to remain attached after the optimizer step (e.g., for gradient accumulation or logging), they will be cleared.
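Because HeavyBall consumes gradients during `step()` and clears `p.grad` (last gotcha above), any gradient logging must happen before the optimizer step. A minimal plain-PyTorch sketch of the ordering (the HeavyBall optimizer call itself is elided):

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Read gradient statistics BEFORE optimizer.step(): HeavyBall clears p.grad there.
grad_norm = torch.linalg.vector_norm(
    torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
)
print(f"grad norm before step: {grad_norm:.4f}")
# optimizer.step() would now consume and clear the gradients.
```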
Install
- `pip install heavyball`
Imports
- AdamW
from heavyball import AdamW  # v3.0.0+; was `from heavyball import ForeachAdamW` before v3.0.0
- SOAP
from heavyball import SOAP
- Muon
from heavyball import Muon
Quickstart
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from heavyball import AdamW # Or any other HeavyBall optimizer
# 1. Define a dummy model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)
model = SimpleModel()
# 2. Prepare dummy data
X = torch.randn(100, 10)
y = torch.randn(100, 1)
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16)
# 3. Initialize the HeavyBall optimizer
optimizer = AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
# 4. Training loop (simplified)
num_epochs = 5
for epoch in range(num_epochs):
    for batch_X, batch_y in dataloader:
        optimizer.zero_grad()
        output = model(batch_X)
        loss = loss_fn(output, batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
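Given the library's compile-first focus, the quickstart can also pair the model with `torch.compile`. The sketch below is hedged: it falls back to `torch.optim.AdamW` when HeavyBall is not installed (relying on the API-compatibility claim), and to eager execution when no compile backend is available on the machine:

```python
import torch
from torch import nn

try:
    from heavyball import AdamW  # HeavyBall's drop-in AdamW
except ImportError:
    from torch.optim import AdamW  # fallback when heavyball is not installed

model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
x, y = torch.randn(16, 10), torch.randn(16, 1)

def train_step(net):
    optimizer.zero_grad()
    loss = loss_fn(net(x), y)
    loss.backward()
    optimizer.step()
    return loss

try:
    loss = train_step(torch.compile(model))  # compiled forward/backward
except Exception:
    loss = train_step(model)  # eager fallback if no compile backend is available
```

The optimizer itself is used unchanged; only the model's forward pass is compiled here, which is the standard `torch.compile` entry point.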