Lightning Fabric

2.6.1 · active · verified Fri Apr 17

Lightning Fabric is a lightweight, high-performance library for training deep learning models at scale with PyTorch. It provides core utilities for distributed training, mixed-precision training, and device management, letting users write pure PyTorch code while Fabric handles the boilerplate. It is currently at version 2.6.1 and follows the release cadence of the broader Lightning ecosystem, typically with monthly or bi-monthly patch releases.

Common errors

Warnings

Install
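
Fabric is distributed as part of the main lightning package on PyPI (a standalone lightning-fabric package also exists); assuming a standard environment, a plain pip install is typically all that's needed:

pip install lightning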

Imports
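
The quickstart below uses only the core Fabric entry point plus standard PyTorch utilities:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric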

Quickstart

This quickstart demonstrates how to set up a simple PyTorch model for training with Lightning Fabric, letting Fabric handle device placement, mixed precision, and distributed-training boilerplate. It covers initialization, model/optimizer/dataloader setup, and a basic training loop in which the backward pass goes through Fabric.

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric

# 1. Initialize Fabric
# Configure accelerator, devices, and precision, e.g. Fabric(accelerator='cpu', precision='bf16-mixed').
# With the defaults below, Fabric uses an available GPU if present, otherwise the CPU.
fabric = Fabric(
    accelerator=os.environ.get('ACCELERATOR', 'auto'),
    devices=os.environ.get('DEVICES', 'auto'),
    precision=os.environ.get('PRECISION', '32-true')
)
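# Note: if a multi-process strategy (e.g., multi-GPU DDP) is launched directly from this
# script rather than via an external launcher, fabric.launch() is typically called here;
# a single-process run like this one works without it.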

# 2. Define your model, optimizer, and data
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Generate dummy data
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16)

# 3. Set up the model, optimizer, and dataloader with Fabric
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

# 4. Training loop
fabric.print(f"Starting training on device: {fabric.device}")
for epoch in range(3):
    for batch_idx, (data, target) in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
        fabric.backward(loss) # Perform backward pass using Fabric
        optimizer.step()

        if batch_idx % 10 == 0:
            fabric.print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

fabric.print("Training complete!")

# Example of saving: fabric.save("model.pt", {"model": model.state_dict(), "optimizer": optimizer.state_dict()})
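# Possible follow-up, restoring that checkpoint (a sketch; "model.pt" is the placeholder path used above):
# checkpoint = fabric.load("model.pt")
# model.load_state_dict(checkpoint["model"])
# optimizer.load_state_dict(checkpoint["optimizer"])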
