Lightning AI Framework
Lightning is a deep learning framework built on PyTorch, simplifying the training, deployment, and scaling of AI models. It abstracts away boilerplate code, allowing researchers and engineers to focus on model logic. The current stable version is 2.6.1, and it maintains a rapid release cadence with minor versions typically released every 1-2 months, alongside frequent patch updates.
Warnings
- breaking In v2.0, PyTorch-specific components moved from the `pytorch_lightning` package to the `lightning.pytorch` namespace of the unified `lightning` package. With only `lightning` installed, imports of `pytorch_lightning` fail; the standalone `pytorch-lightning` package still provides the old name for legacy code.
- gotcha For a single optimizer, `LightningModule.configure_optimizers()` can return the optimizer instance directly; wrapping it in a one-element list is unnecessary. Lists, or optimizer/scheduler dicts, remain supported for more complex setups.
- gotcha Lightning automatically handles device placement for models, data, and optimizers. Manually calling `.to(device)` on models or tensors within `training_step` or similar methods is usually unnecessary and can lead to bugs or redundant operations.
- deprecated The `to_torchscript` method on `LightningModule` has been deprecated.
- gotcha `LightningCLI` scripts are run with an explicit subcommand, e.g. `python your_script.py fit`; omitting the subcommand is not supported. The separate `lightning run model` launcher applies to Fabric scripts, not to `LightningCLI`.
Install
- pip install lightning
- pip install 'lightning[pytorch]' # for PyTorch-specific dependencies
Imports
- LightningModule
from lightning.pytorch import LightningModule
- Trainer
from lightning.pytorch import Trainer
- ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint
- LightningDataModule
from lightning.pytorch import LightningDataModule
Quickstart
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.linear(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer
# 1. Prepare dummy data
x_data = torch.randn(100, 10)
y_data = torch.randn(100, 1)
dataset = TensorDataset(x_data, y_data)
dataloader = DataLoader(dataset, batch_size=32)
# 2. Instantiate model and trainer
model = SimpleModel()
trainer = Trainer(max_epochs=5, enable_progress_bar=False, enable_checkpointing=False)
# 3. Train the model
trainer.fit(model, dataloader)
print("Training complete for a simple Lightning model.")