PyTorch-Ignite

0.5.4 · active · verified Thu Apr 16

PyTorch-Ignite is a lightweight and user-friendly library designed to simplify training and evaluating neural networks with PyTorch. It provides a high-level API for setting up training loops, handling events, and integrating various experiment tracking tools. Currently at version 0.5.4, it maintains an active release cadence with frequent bug fixes and feature enhancements.

Install

pip install pytorch-ignite

Quickstart

This quickstart demonstrates setting up a basic training loop with PyTorch-Ignite. It defines a simple PyTorch model, creates a trainer and an evaluator with `create_supervised_trainer` and `create_supervised_evaluator`, attaches a handler that evaluates the model and prints accuracy and loss after each epoch, and runs training for two epochs. It uses randomly generated dummy data, so it runs as-is; for brevity, the evaluator reuses the training data.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# 1. Define a simple model, optimizer, loss function
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 2. Create dummy data
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=10)

# 3. Create trainer and evaluator
trainer = create_supervised_trainer(model, optimizer, criterion)
# The evaluator takes no loss function argument; the loss is computed by the Loss metric
evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy(), 'nll': Loss(criterion)})

# 4. Define handlers for events
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    evaluator.run(dataloader)
    metrics = evaluator.state.metrics
    print(f"Epoch {engine.state.epoch}/{engine.state.max_epochs} - Avg accuracy: {metrics['accuracy']:.2f}, Avg loss: {metrics['nll']:.2f}")

# 5. Run the training
trainer.run(dataloader, max_epochs=2)

print("\nTraining complete.")
