{"id":9033,"library":"heavyball","title":"HeavyBall: Compile-first PyTorch optimizer library","description":"HeavyBall is a PyTorch optimizer library that emphasizes 'compile-first' design, assembling optimizers from composable, compiled building blocks. It provides API-compatible replacements for `torch.optim` optimizers like AdamW, SGD, and RMSprop, along with over 30 specialized optimizers such as Muon, SOAP/Shampoo, PSGD, and Schedule-Free. Currently at version 3.0.0, the library is actively maintained with a focus on `torch.compile` fusion, Triton kernel optimization, and memory efficiency, including features like ECC state compression.","status":"active","version":"3.0.0","language":"en","source_language":"en","source_url":"https://github.com/HomebrewML/HeavyBall","tags":["pytorch","optimizer","deep-learning","machine-learning","torch-compile","triton","mixed-precision","memory-efficient","adamw","sgd","soap","psgd","muon"],"install":[{"cmd":"pip install heavyball","lang":"bash","label":"Install latest version"}],"dependencies":[{"reason":"HeavyBall is a PyTorch optimizer library and requires PyTorch >= 2.2 for optimal functionality, especially with `torch.compile` features.","package":"torch"}],"imports":[{"note":"In HeavyBall v3.0.0, `Foreach*` prefixes were removed from optimizer class names to simplify the public API; use the short, canonical names instead.","wrong":"from heavyball import ForeachAdamW","symbol":"AdamW","correct":"from heavyball import AdamW"},{"symbol":"SOAP","correct":"from heavyball import SOAP"},{"symbol":"Muon","correct":"from heavyball import Muon"}],"quickstart":{"code":"import torch\nfrom torch import nn\nfrom torch.utils.data import DataLoader, TensorDataset\nfrom heavyball import AdamW # Or any other HeavyBall optimizer\n\n# 1. 
Define a dummy model\nclass SimpleModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear = nn.Linear(10, 1)\n\n    def forward(self, x):\n        return self.linear(x)\n\nmodel = SimpleModel()\n\n# 2. Prepare dummy data\nX = torch.randn(100, 10)\ny = torch.randn(100, 1)\ndataset = TensorDataset(X, y)\ndataloader = DataLoader(dataset, batch_size=16)\n\n# 3. Initialize the HeavyBall optimizer\noptimizer = AdamW(model.parameters(), lr=1e-3)\nloss_fn = nn.MSELoss()\n\n# 4. Training loop (simplified)\nnum_epochs = 5\nfor epoch in range(num_epochs):\n    for batch_X, batch_y in dataloader:\n        optimizer.zero_grad()  # optional: HeavyBall clears p.grad during step() by default\n        output = model(batch_X)\n        loss = loss_fn(output, batch_y)\n        loss.backward()\n        optimizer.step()\n    print(f\"Epoch {epoch+1}, last-batch loss: {loss.item():.4f}\")","lang":"python","description":"This quickstart demonstrates how to use a HeavyBall optimizer, such as `AdamW`, with a simple PyTorch model and a basic training loop. It covers model and data preparation, optimizer initialization, and the standard `zero_grad()`, `backward()`, and `step()` calls. HeavyBall optimizers are designed as drop-in replacements for `torch.optim` classes."}
Custom SOAP variants created for earlier versions may not work out-of-the-box.","severity":"breaking","affected_versions":">=2.2.0"},{"fix":"To align with standard behavior, set `heavyball.utils.default_division_backend = \"eps_add\"` early in your script. Other options like `atan2` are also available.","message":"HeavyBall's default division backend (`eps_clamp`) differs from the industry standard (`eps_add`) used by PyTorch and Optax, potentially leading to meaningfully different numerical results if not accounted for.","severity":"gotcha","affected_versions":">=2.2.1"},{"fix":"Upgrade to HeavyBall v2.3.1 or later to leverage the internal fixes (manual bit arithmetic). If using older versions, avoid ECC with stochastic rounding.","message":"In versions before v2.3.1, `torch.compile` could fuse away crucial ECC (Error Correction Code) arithmetic, producing incorrect results, particularly with stochastic rounding.","severity":"gotcha","affected_versions":"<2.3.1"},{"fix":"Use the provided `scripts/migrate_optimizer_state.py` utility to convert pre-2.0 optimizer checkpoints. Consult the v2.0.0 and v3.0.0 migration guides for detailed instructions.","message":"HeavyBall v2.0.0 introduced significant numerical stability improvements, more accurate SVD computation, and a reworked chainable backend, all of which affect checkpointing. Optimizer checkpoints saved with HeavyBall v1.x are not directly compatible.","severity":"breaking","affected_versions":">=2.0.0 (from 1.x)"},{"fix":"Initialize your optimizer with `consume_grad=False` (e.g., `AdamW(model.parameters(), lr=1e-3, consume_grad=False)`) to prevent gradients from being cleared automatically.","message":"HeavyBall optimizers, by default, consume gradients during `step()` and clear `p.grad`. 
Training loops that expect gradients to persist after the optimizer step (e.g., for gradient accumulation or logging) will find them silently cleared.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Update your import statements and optimizer instantiations to use the simplified, shorter class names. For `ForeachAdamW`, use `from heavyball import AdamW`.","cause":"Attempting to use an optimizer name with the `Foreach*` prefix (e.g., `ForeachAdamW`) after upgrading to HeavyBall v3.0.0 or later, where these prefixes were removed.","error":"AttributeError: module 'heavyball' has no attribute 'ForeachAdamW'"},{"fix":"For checkpoints saved with HeavyBall v1.x, use the `scripts/migrate_optimizer_state.py` utility provided in the repository. For v2.x checkpoints, consult the v3.0.0 migration guide for specific conversion steps if any are needed.","cause":"Loading a model or optimizer checkpoint saved with an older version of HeavyBall (e.g., v1.x or v2.x) into a newer version (v2.0.0+ or v3.0.0+) without the migration steps required by changes in the internal state representation.","error":"RuntimeError: Error(s) in loading state_dict for SimpleModel: Unexpected key(s) in state_dict: \"optimizer_states.0.state.step\"."},{"fix":"To match the standard behavior, set the division backend globally before initializing optimizers: `import heavyball.utils; heavyball.utils.default_division_backend = \"eps_add\"`.","cause":"The default division backend used by HeavyBall (`eps_clamp`) for calculating adaptive learning rates or update norms differs from the `eps_add` method commonly used in `torch.optim` and Optax, leading to numerical discrepancies.","error":"Optimizer step produces significantly different or worse convergence compared to `torch.optim` or Optax."},{"fix":"Change your import statement from `from heavyball.optimizers import AdamW` to 
`from heavyball import AdamW`.","cause":"Incorrect import path for optimizers. HeavyBall optimizers are typically available directly under the `heavyball` namespace, not a nested `heavyball.optimizers` module.","error":"ModuleNotFoundError: No module named 'heavyball.optimizers'"}]}