Torchtnt

0.2.4 · active · verified Thu Apr 16

Torchtnt is a lightweight library by PyTorch providing training tools and utilities. It is closely integrated with PyTorch and designed for rapid iteration with any model or training regimen. It offers powerful dataloading, logging, and visualization utilities. As of version 0.2.4, it is actively maintained by PyTorch and released as needed. It is currently in a pre-alpha development stage, so the API may still change between releases.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates a basic training loop built with Torchtnt. It defines a simple linear model, creates a custom training unit that handles the forward and backward passes, prepares a dummy dataset, and then executes the training loop. Metrics are logged with `TensorBoardLogger` to the specified log directory.

import logging
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.fit import fit
from torchtnt.framework.train import train
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils import init_from_env, seed
from torchtnt.utils.loggers import TensorBoardLogger

# Configure the root logger so torchtnt's informational messages are visible.
logging.basicConfig(level=logging.INFO)

# 1. Define your model
class SimpleModel(nn.Module):
    """Minimal regression model: a single affine (fully connected) layer."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # One linear layer mapping input_dim features to output_dim outputs.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # y = x @ W^T + b
        out = self.linear(x)
        return out

# 2. Define your training unit
class MyTrainingUnit(TrainUnit):
    """Training unit that runs a manual forward/backward/optimizer step per batch.

    NOTE(review): the original subclassed ``AutoUnit``, but ``AutoUnit.__init__``
    requires a ``module`` argument (the bare ``super().__init__()`` here would
    raise) and expects ``compute_loss`` / ``configure_optimizers_and_lr_scheduler``
    overrides instead of a hand-written ``train_step``. ``TrainUnit`` is the
    base class for manually written steps, which is what this code does.

    Args:
        model: the module to optimize.
        optimizer: optimizer updating ``model``'s parameters.
        logger: destination for per-step scalar metrics.
    """

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer, logger: TensorBoardLogger):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = nn.MSELoss()  # mean-squared error for the toy regression task
        self.logger = logger

    def train_step(self, state: object, data: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Run one optimization step on a single (inputs, targets) batch."""
        inputs, targets = data
        outputs = self.model(inputs)
        loss = self.loss_fn(outputs, targets)

        # Standard manual optimization: clear grads, backprop, apply update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # train_progress is maintained by the framework's training loop.
        self.logger.log_scalar("train_loss", loss.item(), step=self.train_progress.num_steps_completed)

# 3. Prepare data
class RandomDataset(Dataset):
    """Synthetic dataset of Gaussian-noise feature/target pairs."""

    def __init__(self, num_samples, input_dim, output_dim):
        # Draw every sample up front; row i is example i.
        self.data = torch.randn(num_samples, input_dim)
        self.labels = torch.randn(num_samples, output_dim)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# Hyperparameters for the toy regression task.
input_dim = 10
output_dim = 1
num_samples = 1000
batch_size = 32
num_epochs = 2

dataset = RandomDataset(num_samples, input_dim, output_dim)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 4. Initialize model, optimizer, and logger
model = SimpleModel(input_dim, output_dim)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Log location can be overridden via the TORCHTNT_LOG_DIR environment variable.
log_dir = os.path.join(os.environ.get('TORCHTNT_LOG_DIR', './runs'), 'my_experiment')
os.makedirs(log_dir, exist_ok=True)
logger = TensorBoardLogger(log_dir)

# 5. Create training unit and run the training loop
training_unit = MyTrainingUnit(model, optimizer, logger)

print(f"Starting training for {num_epochs} epochs...")
# Use train() for a train-only loop: fit() also requires an eval_dataloader,
# which this example never constructs, so the original fit() call would fail.
train(training_unit, train_dataloader=dataloader, max_epochs=num_epochs)
# Report the actual (possibly env-overridden) log directory, not a hard-coded one.
print(f"Training complete! Check logs in '{log_dir}'.")

view raw JSON →