Lightning Fabric
Lightning Fabric is a lightweight, high-performance library for scaling PyTorch model training. It provides core utilities for distributed training, mixed-precision training, and device management, letting you write plain PyTorch code while Fabric handles the boilerplate. It is currently at version 2.6.1 and follows the release cadence of the broader Lightning ecosystem, typically with monthly or bi-monthly patch releases.
Common errors
- RuntimeError: Expected all tensors to be on the same device, but found at least two devices
  Cause: Tensors are inadvertently created on or moved to different devices (e.g., CPU vs. GPU, or different GPUs) without Fabric's orchestration, often in data loading or custom modules.
  Fix: Pass models and optimizers through `fabric.setup()` and data loaders through `fabric.setup_dataloaders()`. Verify that any custom tensor creation inside your model or dataloader explicitly uses `fabric.device` or is handled by Fabric's setup.
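The fix boils down to deriving the target device from what was already placed, rather than hard-coding one. A minimal sketch in plain PyTorch, assuming a CPU-only run (inside a Fabric script you would use `fabric.device` instead of inspecting the parameters):

```python
import torch

model = torch.nn.Linear(10, 2)  # wherever Fabric (or you) placed it
device = next(model.parameters()).device  # with Fabric, prefer fabric.device

# Create new tensors on the model's device instead of the default (CPU),
# so forward passes never mix devices
mask = torch.ones(10, device=device)
out = model(torch.randn(4, 10, device=device) * mask)
print(out.shape)  # torch.Size([4, 2])
```

The same pattern applies to buffers, masks, or positional encodings built on the fly inside `forward`.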
- AttributeError: 'Fabric' object has no attribute 'fit'
  Cause: Lightning Fabric is often confused with PyTorch Lightning's Trainer. Fabric does not provide a high-level `fit` method.
  Fix: Write the training loop explicitly. Fabric provides tools like `fabric.setup()` and `fabric.backward()` to help, but loop control is manual.
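Because loop control is manual, the smallest replacement for `fit` is an ordinary PyTorch loop. A sketch on dummy regression data; the comments mark the two lines Fabric would change:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

# With Fabric you would first call: model, optimizer = fabric.setup(model, optimizer)
for epoch in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(X), y)
    loss.backward()  # with Fabric: fabric.backward(loss)
    optimizer.step()
print(f"final loss: {loss.item():.4f}")
```

Validation loops, checkpointing, and logging are likewise yours to add around this skeleton.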
- ImportError: cannot import name 'Fabric' from 'pytorch_lightning.fabric'
  Cause: The `Fabric` class moved to the top-level `lightning.fabric` package in the Lightning 2.0 refactor; the old import path is no longer valid.
  Fix: Update the import to `from lightning.fabric import Fabric`.
- TypeError: Expected one of device type attribute to be present, but found none.
  Cause: Fabric cannot infer an appropriate device, or a specific accelerator is requested without being properly installed or configured (e.g., requesting CUDA when PyTorch was installed without CUDA support).
  Fix: Ensure `torch` is installed with CUDA support if you intend to use GPUs, and set `accelerator` and `devices` explicitly when initializing Fabric (e.g., `Fabric(accelerator='cuda', devices=1)` or `Fabric(accelerator='cpu')`).
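A defensive way to set the accelerator explicitly is to detect CUDA yourself and fall back to CPU. This sketch uses plain `torch` device detection; the `Fabric(...)` line is commented out so the snippet runs without Lightning installed:

```python
import torch

# Pick the accelerator explicitly instead of relying on auto-detection
accelerator = "cuda" if torch.cuda.is_available() else "cpu"
devices = 1

# fabric = Fabric(accelerator=accelerator, devices=devices)
print(f"would run on: {accelerator}")
```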
Warnings
- breaking: Users migrating from PyTorch Lightning 1.x to the Lightning 2.x ecosystem (which includes Fabric) will encounter significant API changes. Fabric takes a more explicit, less opinionated approach to distributed training and device placement than the `Trainer` in PL 1.x.
- gotcha: Lightning Fabric provides low-level distributed training primitives. It does not include a full `Trainer` abstraction like `pytorch_lightning.Trainer`; you write your own training and validation loops and handle callbacks manually.
- gotcha: Incorrect device placement (e.g., calling `.cuda()` or `.to(device)` directly on tensors or modules after Fabric has already set them up) can lead to unexpected behavior or errors, especially in distributed environments.
Install
- pip install lightning-fabric torch
  (Installing the full `lightning` package also includes Fabric: `pip install lightning`.)
Imports
- Fabric
from lightning.fabric import Fabric  # Lightning >= 2.0
# deprecated (pre-2.0 path, no longer valid): from pytorch_lightning.fabric import Fabric
Quickstart
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric
# 1. Initialize Fabric
# Configure accelerators, devices, precision. E.g., Fabric(accelerator='cpu', precision='bf16')
# For a basic setup, Fabric() will try to use available GPUs or CPU.
fabric = Fabric(
    accelerator=os.environ.get('ACCELERATOR', 'auto'),
    devices=os.environ.get('DEVICES', 'auto'),
    precision=os.environ.get('PRECISION', '32-true')
)
# 2. Define your model, optimizer, and data
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Generate dummy data
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16)
# 3. Setup the model, optimizer, and data for distributed training
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)
# 4. Training loop
fabric.print(f"Starting training on device: {fabric.device}")
for epoch in range(3):
    for batch_idx, (data, target) in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)
        fabric.backward(loss)  # perform the backward pass through Fabric
        optimizer.step()
        if batch_idx % 10 == 0:
            fabric.print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
fabric.print("Training complete!")
# Example of saving: fabric.save("model.pt", {"model": model.state_dict(), "optimizer": optimizer.state_dict()})