Schedule-Free Optimization in PyTorch
Schedule-Free is a PyTorch library of optimizers for 'schedule-free' learning, which eliminates the need for traditional learning rate schedules: the momentum of a base optimizer is replaced by a combination of interpolation and averaging, so neither the stopping time nor the total number of steps needs to be specified in advance, while aiming for faster training. The library, currently at version 1.4.1, offers schedule-free variants of popular optimizers such as SGD, AdamW, and RAdam, and is actively maintained by Facebook Research.
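The mechanism can be sketched in a few lines. The following is an illustrative, pure-Python rendition of the schedule-free SGD recursion (interpolation point `y`, base iterate `z`, running average `x`) on a toy quadratic; it mirrors the idea behind the library, not its actual implementation, and all names here are for demonstration only:

```python
# Illustrative sketch of the schedule-free SGD recursion on f(w) = w^2.
# Gradients are evaluated at an interpolation y between the base iterate z
# and the running average x; x is what you evaluate/deploy.

def grad(w):
    # Gradient of the toy objective f(w) = w^2
    return 2.0 * w

def schedule_free_sgd(w0, lr=0.1, beta=0.9, steps=1000):
    z = w0  # base SGD iterate
    x = w0  # running average (the evaluation point)
    for t in range(1, steps + 1):
        y = (1 - beta) * z + beta * x   # interpolation: gradient taken at y
        z = z - lr * grad(y)            # plain SGD step on z
        x = x + (z - x) / t             # equal-weight online average of z
    return x

final_x = schedule_free_sgd(w0=5.0)
print(f"final averaged iterate x: {final_x:.6f}")  # drifts toward the minimizer at 0
```

Note that no learning rate decay appears anywhere: the averaging of `z` into `x` plays the role a decaying schedule usually does.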
Common errors
- Exception: "Optimizer was not in train mode when step is called. optimizer.train() must be called before optimizer.step(). See documentation for details."
  cause: The `optimizer.train()` method was not called before `optimizer.step()` during the training loop.
  fix: Add `optimizer.train()` at the beginning of your training epoch or step, similar to how `model.train()` is used.
- Incorrect or inconsistent evaluation results, especially with models using BatchNorm layers.
  cause: Batch normalization statistics during evaluation are computed from the intermediate `y` sequence of the optimizer instead of the final `x` sequence, leading to discrepancies.
  fix: Implement explicit handling for BatchNorm layers during evaluation so that statistics are updated from the `x` sequence, or use `PreciseBN` if applicable.
- Suboptimal convergence or performance compared to traditional optimizers with well-tuned learning rate schedules.
  cause: Despite being 'schedule-free', the library still requires careful tuning of other hyperparameters, including the initial learning rate, regularization, and potentially the `beta` parameter. Starting with default values without adjustment may yield poor results.
  fix: Tune the learning rate (often higher than with scheduled optimizers: roughly 10x-50x for SGD, 1x-10x for AdamW) and regularization parameters. Consider increasing `beta` for very long training runs (e.g., to 0.95 or 0.98).
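The BatchNorm fix above can be sketched as a PreciseBN-style recomputation pass. The helper below is an illustrative, torch-only sketch (the function name `refresh_bn_stats` and the calling pattern are assumptions, not library API): it resets BatchNorm running statistics and rebuilds them by streaming batches through the model in train mode with gradients disabled. Run it after `optimizer.eval()` so the statistics reflect the `x`-sequence weights:

```python
import torch
import torch.nn as nn

def refresh_bn_stats(model, batches, momentum=None):
    """Recompute BatchNorm running statistics by streaming data through
    the model in train mode, without updating any weights.

    Intended to be called AFTER optimizer.eval(), so forward passes see
    the x-sequence weights; evaluate afterwards with model.eval() as usual.
    """
    saved = []
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            saved.append((m, m.momentum))
            m.reset_running_stats()
            m.momentum = momentum  # None => cumulative moving average
    model.train()
    with torch.no_grad():
        for xb in batches:
            model(xb)
    # Restore original momentum settings and return to eval mode
    for m, mom in saved:
        m.momentum = mom
    model.eval()
```

Usage sketch (assuming a Schedule-Free optimizer and a training loader): `optimizer.eval()`, then `refresh_bn_stats(model, train_loader)`, then run validation.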
Warnings
- gotcha Schedule-Free optimizers require explicit calls to `optimizer.train()` and `optimizer.eval()` to manage internal parameter buffers correctly during training and evaluation phases, respectively. Forgetting these calls can lead to incorrect updates or runtime errors.
- breaking In version 1.3, the behavior of weight decay during learning rate warmup was changed to improve stability and consistency with standard `AdamW` in PyTorch.
- gotcha If your model uses BatchNorm layers, additional handling is needed for test/validation evaluations to behave correctly, because batch statistics must be computed from the `x` sequence, not the `y` sequence.
- gotcha Training with Schedule-Free optimizers can be more sensitive to the choice of the `beta` parameter than with standard momentum. While the default `0.9` works for many problems, increasing it to `0.95` or `0.98` might be necessary for very long training runs to achieve optimal performance.
- gotcha The optimal learning rates for Schedule-Free optimizers are typically higher than those used with schedule-based approaches. For SGD, a learning rate 10x-50x larger might be a good starting point, while for AdamW, 1x-10x larger rates are often effective.
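To make the required mode toggling harder to forget, evaluation phases can be wrapped in a small context manager. This is a convenience pattern, not part of the library; it only assumes the optimizer exposes `train()`/`eval()` as the Schedule-Free optimizers do, and it restores train mode afterwards (so it is intended for use inside a training loop):

```python
from contextlib import contextmanager

@contextmanager
def schedule_free_eval(model, optimizer):
    """Temporarily switch both the model AND the optimizer to eval mode,
    then restore train mode on exit, even if the body raises."""
    model.eval()
    optimizer.eval()
    try:
        yield
    finally:
        model.train()
        optimizer.train()

# Usage sketch inside a training loop:
# with schedule_free_eval(model, optimizer):
#     validate(model, val_loader)   # hypothetical validation routine
```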
Install
- pip install schedulefree
Imports
- SGDScheduleFree
from schedulefree import SGDScheduleFree
- AdamWScheduleFree
from schedulefree import AdamWScheduleFree
- RAdamScheduleFree
from schedulefree import RAdamScheduleFree
- ScheduleFreeWrapper
from schedulefree import ScheduleFreeWrapper
Quickstart
import torch
import torch.nn as nn
from schedulefree import AdamWScheduleFree

# 1. Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()

# 2. Define a Schedule-Free optimizer (e.g., AdamWScheduleFree)
# Note: Schedule-Free optimizers often benefit from higher learning rates than traditional ones.
optimizer = AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)

# 3. Define a loss function
criterion = nn.MSELoss()

# 4. Dummy data for demonstration
inputs = torch.randn(64, 10)
targets = torch.randn(64, 1)

# 5. Training loop (simplified)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()      # Standard PyTorch model training mode
    optimizer.train()  # REQUIRED for Schedule-Free optimizers

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 2 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluation phase (e.g., for validation or checkpointing)
model.eval()      # Standard PyTorch model evaluation mode
optimizer.eval()  # REQUIRED for Schedule-Free optimizers before evaluation/checkpointing
with torch.no_grad():
    val_inputs = torch.randn(16, 10)
    val_targets = torch.randn(16, 1)
    val_outputs = model(val_inputs)
    val_loss = criterion(val_outputs, val_targets)
    # In a real scenario, you'd compute metrics on val_outputs and val_targets

print("Training complete.")
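One detail the quickstart only hints at: checkpoints should be written in eval mode, so the saved weights come from the `x` sequence rather than the intermediate `y` sequence. A hedged sketch of such a helper follows (the function name and checkpoint keys are assumptions, not library API; the optimizer is only assumed to expose `train()`/`eval()` and `state_dict()`):

```python
import torch

def save_checkpoint(model, optimizer, path):
    """Save a checkpoint with Schedule-Free-safe mode handling.

    Switching the optimizer to eval mode first ensures the weights being
    saved are the x sequence; train mode is restored afterwards if the
    model was training, so the loop can continue unaffected.
    """
    was_training = model.training
    model.eval()
    optimizer.eval()  # move weights to the x sequence before saving
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict()}, path)
    if was_training:
        model.train()
        optimizer.train()  # back to the y sequence for further training
```

When resuming, the same logic applies in reverse: load the state dicts, then call `optimizer.train()` before the next `optimizer.step()`.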