PyTorch Lightning
PyTorch Lightning is a lightweight PyTorch wrapper designed to simplify the training and evaluation of deep learning models. It abstracts away common boilerplate code, allowing researchers and engineers to focus on model architecture and experimental logic. The library is actively maintained, currently at version 2.6.1, and follows a release cadence in which minor versions may introduce backwards-incompatible changes accompanied by deprecation warnings, while major versions may remove deprecated APIs without warning.
Warnings
- breaking Major API and package renaming in version 2.0. The primary package name for installation changed from `pytorch-lightning` to `lightning`, and imports moved from `pytorch_lightning` (e.g., `pytorch_lightning.Trainer`) to `lightning` (e.g., `lightning.Trainer`). Additionally, many `Trainer` arguments, such as `gpus` and `tpu_cores`, were deprecated in 1.x and removed in 2.0 in favor of accelerator configurations (e.g., `accelerator='gpu', devices=4`).
- deprecated The `to_torchscript` method on `LightningModule` was deprecated in version 2.6.1.
- gotcha Manual device placement (e.g., `.cuda()`, `.to(device)`) is generally not needed within a `LightningModule` and can cause issues. Lightning's `Trainer` handles device management automatically.
- gotcha For distributed training, `DistributedSampler` is automatically applied to `DataLoader`s by the `Trainer` when a distributed strategy is used. Manually wrapping your `DataLoader` with `DistributedSampler` can lead to incorrect behavior or errors.
Install
- pip install lightning (current package name, 2.x)
- pip install pytorch-lightning (legacy package name, still published)
Imports
- LightningModule
from lightning import LightningModule
- Trainer
from lightning import Trainer
Quickstart
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
# 1. Define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# 2. Define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# 3. Define a dataset
dataset = MNIST(os.environ.get('DATASET_PATH', os.getcwd()), download=True, transform=ToTensor())
train_dataloader = utils.data.DataLoader(dataset, batch_size=128)
# 4. Train the model
model = LitAutoEncoder(encoder, decoder)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model, train_dataloader)
# 5. Use the model (optional, example prediction step)
# For inference, set model to eval mode and disable gradients
import torch  # needed for torch.no_grad()

model.eval()
with torch.no_grad():
    sample_input, _ = dataset[0]
    sample_input = sample_input.view(1, -1)
    encoded_output = model.encoder(sample_input)
    decoded_output = model.decoder(encoded_output)
    print(f"Original shape: {sample_input.shape}, Encoded shape: {encoded_output.shape}, Decoded shape: {decoded_output.shape}")