Schedule-Free Optimization in PyTorch

1.4.1 · active · verified Thu Apr 16

Schedule-Free is a PyTorch library from Facebook Research that provides optimizers designed for "schedule-free" learning, eliminating the need for traditional learning rate schedules. It aims to achieve faster training without requiring the stopping time or total number of steps to be specified in advance. The library, currently at version 1.4.1 and actively maintained, offers schedule-free variants of popular optimizers such as SGD, AdamW, and RAdam.

Common errors

Warnings

Install
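The library is published on PyPI under the name `schedulefree`. A minimal install, assuming a working Python environment with PyTorch already available:

```shell
pip install schedulefree
```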

Imports

Quickstart

This quickstart demonstrates how to integrate `AdamWScheduleFree` into a basic PyTorch training loop. Key steps include initializing the optimizer, performing forward/backward passes, and crucially, calling `optimizer.train()` and `optimizer.eval()` alongside `model.train()` and `model.eval()` for correct parameter buffer handling during training and evaluation/checkpointing.

import torch
import torch.nn as nn
from schedulefree import AdamWScheduleFree

# 1. Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# 2. Define a Schedule-Free optimizer (e.g., AdamWScheduleFree)
# Note: Schedule-Free optimizers often benefit from higher learning rates than traditional ones.
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)

# 3. Define a loss function
criterion = nn.MSELoss()

# 4. Dummy data for demonstration
inputs = torch.randn(64, 10)
targets = torch.randn(64, 1)

# 5. Training loop (simplified)
num_epochs = 10
for epoch in range(num_epochs):
    model.train() # Standard PyTorch model training mode
    optimizer.train() # REQUIRED for Schedule-Free optimizers

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 2 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    # Evaluation phase (e.g., for validation or checkpointing)
    model.eval() # Standard PyTorch model evaluation mode
    optimizer.eval() # REQUIRED for Schedule-Free optimizers before evaluation/checkpointing
    with torch.no_grad():
        val_inputs = torch.randn(16, 10)
        val_targets = torch.randn(16, 1)
        val_outputs = model(val_inputs)
        val_loss = criterion(val_outputs, val_targets)
        # In a real scenario, you'd calculate metrics on val_outputs and val_targets

print("Training complete.")
