Lion Optimizer for PyTorch

0.2.4 · active · verified Fri Apr 17

lion-pytorch is an efficient PyTorch implementation of the Lion optimizer, introduced in the paper 'Symbolic Discovery of Optimization Algorithms'. Lion often matches or outperforms AdamW and other adaptive optimizers, especially on large-scale models, thanks to its sign-based update rule, and it is more memory-efficient because it tracks only momentum rather than both first and second moments. The library is actively maintained, currently at version 0.2.4, and requires Python 3.9+.
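The sign-based update mentioned above can be sketched in a few lines of plain Python. This is an illustrative sketch of the rule from the paper, not the library's API; the function name `lion_step` and its scalar form are inventions for clarity:

```python
import math

def lion_step(p, g, m, lr=1e-4, beta1=0.9, beta2=0.99, wd=1e-2):
    """One Lion update for a single scalar parameter (illustrative sketch)."""
    # Interpolate momentum with the current gradient, then keep only the sign
    c = beta1 * m + (1 - beta1) * g
    sign = math.copysign(1.0, c) if c != 0 else 0.0
    # Decoupled weight decay, followed by a fixed-magnitude sign step
    p = p * (1 - lr * wd) - lr * sign
    # Momentum is a slower exponential moving average of the gradients
    m = beta2 * m + (1 - beta2) * g
    return p, m
```

Because the step (before weight decay) always has magnitude exactly `lr`, Lion is typically run with a smaller learning rate than AdamW.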

Common errors

Warnings

Install

pip install lion-pytorch

Imports

from lion_pytorch import Lion

Quickstart

This quickstart demonstrates how to initialize a PyTorch model and the Lion optimizer, then perform a single forward and backward pass to update the model parameters. It highlights the typical workflow for integrating Lion into a PyTorch training loop.

import torch
from torch import nn
from lion_pytorch import Lion

# 1. Define a simple PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)
        self.relu = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.relu(self.linear(x)))

model = SimpleModel()

# 2. Instantiate the Lion optimizer
# Note: Lion typically wants a learning rate 3-10x smaller than AdamW's
# (e.g., 1e-4), often paired with a correspondingly larger weight decay
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

# 3. Create dummy data and target
inputs = torch.randn(32, 10) # 32 samples, 10 features
targets = torch.randn(32, 1) # 32 samples, 1 target value

# 4. Define a loss function
criterion = nn.MSELoss()

# 5. Training loop (one step for quickstart demonstration)
optimizer.zero_grad()           # Clear gradients from previous step
outputs = model(inputs)         # Forward pass
loss = criterion(outputs, targets) # Calculate loss
loss.backward()                 # Backward pass (compute gradients)
optimizer.step()                # Update model parameters

print(f"Loss after one optimization step: {loss.item():.4f}")
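To see why the learning rate matters so much for Lion, here is a dependency-free toy (not using the library itself) that minimizes f(p) = (p - 3)^2 with a scalar Lion-style loop. Weight decay is omitted, and the betas are shortened from the library defaults of (0.9, 0.99) so the toy settles within a couple hundred steps:

```python
# Toy scalar Lion loop minimizing f(p) = (p - 3)^2 (weight decay omitted).
# Every step moves p by exactly lr, so lr directly bounds how finely
# the optimizer can settle around the minimum.
p, m = 0.0, 0.0
lr, beta1, beta2 = 0.1, 0.5, 0.5

for _ in range(200):
    g = 2.0 * (p - 3.0)               # gradient of (p - 3)^2
    c = beta1 * m + (1 - beta1) * g   # interpolated update direction
    if c != 0:
        p -= lr if c > 0 else -lr     # fixed-magnitude sign step
    m = beta2 * m + (1 - beta2) * g   # momentum EMA

# p ends oscillating within a couple of step sizes of the minimum at 3
```

Shrinking `lr` tightens the final oscillation band but slows the approach, which is one intuition for why Lion is tuned with smaller learning rates than AdamW.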
