PyTorch Lightning

2.6.1 · active · verified Sun Apr 05

PyTorch Lightning is a lightweight wrapper around PyTorch that simplifies training and evaluating deep learning models. It abstracts away common boilerplate code, letting researchers and engineers focus on model architecture and experiment logic. The library is actively maintained, currently at version 2.6.1. Its release policy allows minor versions to introduce backwards-incompatible changes accompanied by deprecation warnings, while major versions may introduce such changes without a deprecation period.

Warnings

Install
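The package is published on PyPI under the name `lightning` (the older `pytorch-lightning` distribution also remains available); a typical install looks like:

```shell
pip install lightning
```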

Imports

Quickstart

This quickstart demonstrates a minimal autoencoder training loop using `lightning`. It covers defining a `LightningModule`, setting up data loaders, and training with the `Trainer`. The code shows how Lightning automatically handles the training loop, backward passes, and optimizer steps, reducing boilerplate. A simple inference step is included to show how to use the trained model.

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# 1. Define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# 2. Define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 3. Define a dataset
dataset = MNIST(os.environ.get('DATASET_PATH', os.getcwd()), download=True, transform=ToTensor())
train_dataloader = utils.data.DataLoader(dataset, batch_size=128)

# 4. Train the model
model = LitAutoEncoder(encoder, decoder)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model, train_dataloader)

# 5. Use the model (optional, example prediction step)
# For inference, switch to eval mode and disable gradient tracking
import torch

model.eval()
with torch.no_grad():
    sample_input, _ = dataset[0]
    sample_input = sample_input.view(1, -1)
    encoded_output = model.encoder(sample_input)
    decoded_output = model.decoder(encoded_output)
    print(f"Original shape: {sample_input.shape}, Encoded shape: {encoded_output.shape}, Decoded shape: {decoded_output.shape}")
