FairScale: PyTorch Large-Scale Training Utilities
FairScale is a PyTorch extension library providing utilities for large-scale and high-performance training, including Fully Sharded Data Parallel (FSDP) and Optimizer State Sharding (OSS). While many features, especially FSDP, have been upstreamed to PyTorch, FairScale offers specialized tools for memory and communication efficiency. The current version is 0.4.13; releases are now infrequent, as the core functionality has been integrated into PyTorch itself.
Warnings
- deprecated FairScale's FSDP (`fairscale.nn.data_parallel.FullyShardedDataParallel`) is largely superseded by PyTorch's native FSDP (`torch.distributed.fsdp.FullyShardedDataParallel`), available since PyTorch 1.11 and stable from 1.12. For new projects, the native PyTorch implementation is strongly encouraged, as it receives ongoing development and optimization.
- breaking FairScale is in maintenance mode: active feature development has shifted to PyTorch's native distributed modules. New APIs or features in PyTorch's core distributed components may not be backported to, or remain compatible with, FairScale.
- gotcha FairScale requires a properly initialized `torch.distributed` environment. Running without `dist.init_process_group` (even for single-GPU FSDP) will result in errors or unexpected behavior during model wrapping or training.
- gotcha When using FairScale's FSDP with mixed precision, ensure the `mixed_precision` argument to `FSDP` is configured correctly, or that a compatible `torch.cuda.amp.GradScaler` is used outside of FSDP, depending on your PyTorch version and setup. Misconfiguration can cause performance regressions or `NaN` gradients.
Install
-
pip install fairscale
Imports
- FullyShardedDataParallel
from fairscale.nn.data_parallel import FullyShardedDataParallel
- OSS
from fairscale.optim.oss import OSS
Quickstart
import torch
import torch.nn as nn
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.oss import OSS
# NOTE: For actual distributed use, dist.init_process_group must be called for multi-GPU/node setups.
# This example simulates a single-process setup for quickstart.
# In a real distributed run, rank and world_size would come from the environment.
# Dummy initialization for single-process quickstart
if not dist.is_initialized():
    try:
        # Using HashStore for a simple single-node, single-process initialization
        dist.init_process_group(backend='gloo', rank=0, world_size=1, store=dist.HashStore())
    except RuntimeError as e:
        # Catch if already initialized (e.g., in some interactive environments)
        print(f"Could not initialize process group (might be already initialized): {e}")
# 1. Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)
# 2. Instantiate the model
model = MyModel()
# 3. Wrap the model with FairScale's FSDP
# For simplicity, default options are used. Real-world usage often requires careful tuning.
fsdp_model = FSDP(model)
# 4. Wrap the optimizer with FairScale's OSS
# OSS expects the optimizer *class* (not an instance); keyword arguments
# such as lr are forwarded to that class when building each shard.
oss_optimizer = OSS(params=fsdp_model.parameters(), optim=torch.optim.Adam, lr=1e-3)
# 5. Dummy data and training step
input_data = torch.randn(2, 10)
labels = torch.randn(2, 10)
# Forward pass
output = fsdp_model(input_data)
loss = nn.MSELoss()(output, labels)
# Backward pass and optimizer step
oss_optimizer.zero_grad()
loss.backward()
oss_optimizer.step()
print(f"FairScale FSDP and OSS example completed. Loss: {loss.item():.4f}")
# Clean up distributed environment if it was initialized by this script
if dist.is_initialized() and dist.get_world_size() == 1:
    dist.destroy_process_group()