Lightning AI Framework

2.6.1 · active · verified Thu Apr 09

Lightning is a deep learning framework built on PyTorch that simplifies training, deploying, and scaling AI models. It abstracts away boilerplate code so researchers and engineers can focus on model logic. The current stable version is 2.6.1; the project maintains a rapid release cadence, with minor versions typically shipped every 1-2 months alongside frequent patch updates.

Warnings

Install
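Lightning is distributed on PyPI; the standard install is:

```shell
pip install lightning
```

This pulls in PyTorch as a dependency if it is not already present.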

Imports

Quickstart

This quickstart defines a simple linear model using `LightningModule`, prepares dummy data with `DataLoader`, and trains it using the `Trainer`. It showcases the minimal setup for defining a model, training step, optimizer, and running a training loop.

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        # A single linear layer mapping 10 input features to 1 output
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # delegates to forward()
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# 1. Prepare dummy data
x_data = torch.randn(100, 10)
y_data = torch.randn(100, 1)
dataset = TensorDataset(x_data, y_data)
dataloader = DataLoader(dataset, batch_size=32)

# 2. Instantiate model and trainer
model = SimpleModel()
trainer = Trainer(max_epochs=5, enable_progress_bar=False, enable_checkpointing=False)

# 3. Train the model
trainer.fit(model, dataloader)

print("Training complete for a simple Lightning model.")
