HeavyBall: Compile-first PyTorch optimizer library

3.0.0 · active · verified Thu Apr 16

HeavyBall is a PyTorch optimizer library that emphasizes 'compile-first' design, assembling optimizers from composable, compiled building blocks. It provides API-compatible replacements for `torch.optim` optimizers like AdamW, SGD, and RMSprop, along with over 30 specialized optimizers such as Muon, SOAP/Shampoo, PSGD, and Schedule-Free. Currently at version 3.0.0, the library is actively maintained with a focus on `torch.compile` fusion, Triton kernel optimization, and memory efficiency, including features like ECC state compression.

Install
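HeavyBall is distributed on PyPI, so a standard pip install is the expected route (assuming a recent Python and PyTorch environment):

pip install heavyball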

Imports
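Optimizer classes are imported directly from the top-level package, as the quickstart below does; `import heavyball` also works if you prefer qualified names:

import heavyball
from heavyball import AdamW  # any exported optimizer class follows the same pattern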

Quickstart

This quickstart demonstrates how to use a HeavyBall optimizer, such as `AdamW`, with a simple PyTorch model and a basic training loop. It covers model and data preparation, optimizer initialization, and the standard `zero_grad()`, `backward()`, and `step()` calls. HeavyBall optimizers are designed as drop-in replacements for `torch.optim` classes.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from heavyball import AdamW # Or any other HeavyBall optimizer

# 1. Define a dummy model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# 2. Prepare dummy data
X = torch.randn(100, 10)
y = torch.randn(100, 1)
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16)

# 3. Initialize the HeavyBall optimizer
optimizer = AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# 4. Training loop (simplified)
num_epochs = 5
for epoch in range(num_epochs):
    for batch_X, batch_y in dataloader:
        optimizer.zero_grad()
        output = model(batch_X)
        loss = loss_fn(output, batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
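Because HeavyBall optimizers share the `torch.optim` interface, the same loop works unchanged with the library's specialized optimizers and with a compiled model. A minimal sketch continuing the quickstart above (reusing `SimpleModel`, `dataloader`, and `loss_fn`), and assuming the SOAP optimizer mentioned earlier is exported as `heavyball.SOAP` — exact class names can vary between releases, so check `dir(heavyball)` for what your installed version exposes:

import torch
import heavyball

model = SimpleModel()

# Compile the model; HeavyBall's building blocks are themselves built
# around torch.compile, so the optimizer step stays fusion-friendly.
compiled_model = torch.compile(model)

# Swap in a specialized optimizer. `heavyball.SOAP` is an assumed name
# here; consult dir(heavyball) for the names your version exports.
# Parameters are shared between model and compiled_model, so either
# .parameters() call works.
optimizer = heavyball.SOAP(model.parameters(), lr=1e-3)

for batch_X, batch_y in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(compiled_model(batch_X), batch_y)
    loss.backward()
    optimizer.step()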
