FairScale: PyTorch Large-Scale Training Utilities

0.4.13 · maintenance · verified Mon Apr 13

FairScale is a PyTorch extension library providing utilities for large-scale, high-performance training, including Fully Sharded Data Parallel (FSDP) and Optimizer State Sharding (OSS). While many features, most notably FSDP, have been upstreamed to PyTorch itself, FairScale still offers specialized tools for memory and communication efficiency. The current version is 0.4.13; releases are now infrequent, as the core functionality has been integrated into PyTorch.
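
To see why sharding optimizer state matters, consider Adam, which keeps two fp32 buffers (`exp_avg` and `exp_avg_sq`) per parameter. A back-of-the-envelope sketch (the model size and world size below are illustrative assumptions, not FairScale defaults):

```python
def adam_state_bytes_per_rank(num_params: int, world_size: int) -> float:
    """Approximate Adam optimizer-state memory held by each rank.

    Adam stores two fp32 buffers per parameter (exp_avg, exp_avg_sq),
    i.e. 8 bytes per parameter. OSS partitions this state roughly
    evenly across the world_size ranks.
    """
    bytes_per_param = 2 * 4  # two fp32 buffers of 4 bytes each
    return num_params * bytes_per_param / world_size

# A hypothetical 1B-parameter model: ~8 GB of Adam state unsharded,
# ~1 GB per rank when sharded across 8 workers.
unsharded = adam_state_bytes_per_rank(1_000_000_000, world_size=1)
sharded = adam_state_bytes_per_rank(1_000_000_000, world_size=8)
print(f"{unsharded / 1e9:.0f} GB -> {sharded / 1e9:.0f} GB per rank")
```

This is only the optimizer-state term; FSDP additionally shards the parameters and gradients themselves.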

Warnings

Install
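
FairScale is distributed on PyPI; a standard install (PyTorch must already be installed, since FairScale builds on it):

```shell
pip install fairscale
```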

Imports

Quickstart

This quickstart demonstrates how to wrap a PyTorch model with FairScale's Fully Sharded Data Parallel (FSDP) and shard its optimizer state with OSS for memory-efficient training. Note that `dist.init_process_group` is required for multi-GPU/node training; a dummy single-process initialization is used here so the example runs standalone. For new projects, migrating to PyTorch's native FSDP (`torch.distributed.fsdp`) is recommended.

import torch
import torch.nn as nn
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.oss import OSS

# NOTE: For actual distributed use, dist.init_process_group must be called for multi-GPU/node setups.
# This example simulates a single-process setup for quickstart.
# In a real distributed run, rank and world_size would come from the environment.

# Dummy initialization for single-process quickstart
if not dist.is_initialized():
    try:
        # Using HashStore for a simple single-node, single-process initialization
        dist.init_process_group(backend='gloo', rank=0, world_size=1, store=dist.HashStore())
    except RuntimeError as e:
        # Catch if already initialized (e.g., in some interactive environments)
        print(f"Could not initialize process group (might be already initialized): {e}")

# 1. Define a simple model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)
    def forward(self, x):
        return self.layer(x)

# 2. Instantiate the model
model = MyModel()

# 3. Wrap the model with FairScale's FSDP
# For simplicity, default options are used. Real-world usage often requires careful tuning.
fsdp_model = FSDP(model)

# 4. Wrap the optimizer with FairScale's OSS
# OSS takes the optimizer *class* and its keyword arguments, not an instance:
# each rank instantiates an optimizer over only its own shard of the parameters.
oss_optimizer = OSS(params=fsdp_model.parameters(), optim=torch.optim.Adam, lr=1e-3)

# 5. Dummy data and training step
input_data = torch.randn(2, 10)
labels = torch.randn(2, 10)

# Forward pass
output = fsdp_model(input_data)
loss = nn.MSELoss()(output, labels)

# Backward pass and optimizer step
oss_optimizer.zero_grad()
loss.backward()
oss_optimizer.step()

print(f"FairScale FSDP and OSS example completed. Loss: {loss.item():.4f}")

# Clean up distributed environment if it was initialized by this script
if dist.is_initialized() and dist.get_world_size() == 1:
    dist.destroy_process_group()
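
In a real multi-GPU run, rank and world size come from the launcher's environment rather than the dummy `HashStore` above. With PyTorch's `torchrun`, for example (the script name is a placeholder):

```shell
# Launch 8 processes on one node; torchrun sets RANK, LOCAL_RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT, which the default env://
# init_method of dist.init_process_group reads automatically.
torchrun --nproc_per_node=8 train.py
```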
