TorchTNT
TorchTNT is a lightweight library of training tools and utilities for PyTorch. It integrates closely with PyTorch and is designed for rapid iteration with any model or training regimen, offering utilities for data loading, logging, and visualization. As of version 0.2.4 it is actively maintained by the PyTorch team, with releases cut as needed. Its PyPI classifiers mark it as pre-alpha, so the API may change without notice.
Common errors
- ModuleNotFoundError: No module named 'torchtnt'
  Cause: The torchtnt package is not installed in the active Python environment, or the wrong environment is active.
  Fix: Install the library with `pip install torchtnt` or `conda install -c conda-forge torchtnt`. If you use a virtual environment, make sure it is activated.
- RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu
  Cause: An operation mixed tensors or modules that live on different compute devices (e.g., one on the CPU and another on a GPU).
  Fix: Move every tensor and module involved in the operation to the same device with `.to(device)`, where `device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')`.
- RuntimeError: size mismatch, m1: [X, Y], m2: [A, B] (or similar shape errors)
  Cause: The dimensions of an input tensor do not match what a layer or operation expects, most commonly in matrix multiplication or when feeding data to linear layers.
  Fix: Print the `.shape` of every tensor involved in the failing operation, then make the dimensions compatible using `.view()`, `.reshape()`, or `.permute(...).contiguous()`.
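The device and shape fixes above can be sketched in a few lines of plain PyTorch (no torchtnt required; the model and shapes here are illustrative):

```python
import torch
import torch.nn as nn

# Pick one device up front; mixing cpu and cuda tensors in a single op
# raises "Expected all tensors to be on the same device".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 1).to(device)   # module moved to the chosen device
batch = torch.randn(32, 10)           # tensors are created on cpu by default
out = model(batch.to(device))         # move inputs before the forward pass

# For shape errors, print .shape before the failing op, then reshape.
x = torch.randn(32, 2, 5)
print(x.shape)                        # torch.Size([32, 2, 5])
flat = x.reshape(32, 10)              # 2 * 5 = 10 matches in_features
out2 = model(flat.to(device))
print(out.shape, out2.shape)          # both torch.Size([32, 1])
```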
Warnings
- Breaking: TorchTNT is in '2 - Pre-Alpha' development status according to its PyPI classifiers. The API is experimental and subject to frequent, significant changes, with no strict backward-compatibility guarantees between minor or even patch versions.
- Gotcha: TorchTNT is built on PyTorch, so a working PyTorch installation is a prerequisite. Problems often stem from a broken PyTorch install or from version incompatibilities (especially with CUDA-enabled builds).
- Gotcha: Because TorchTNT operates on PyTorch tensors and modules, common PyTorch runtime errors such as device mismatches, shape mismatches, and dtype errors apply directly. These are frequent sources of frustration for PyTorch developers.
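A quick way to verify the PyTorch prerequisite before suspecting torchtnt itself (uses only PyTorch):

```python
import torch

# Confirm the interpreter sees a working PyTorch build.
print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # False on cpu-only builds

# A tiny tensor op exercises the core runtime.
x = torch.ones(2, 2)
assert (x @ x).sum().item() == 8.0
```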
Install
- pip install torchtnt
- conda install -c conda-forge torchtnt
Imports
- TrainUnit
from torchtnt.framework.unit import TrainUnit
- train
from torchtnt.framework.train import train
- AutoUnit
from torchtnt.framework.auto_unit import AutoUnit
- fit
from torchtnt.framework.fit import fit
- TensorBoardLogger
from torchtnt.utils.loggers import TensorBoardLogger
- init_from_env
from torchtnt.utils import init_from_env
Quickstart
import logging
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils import init_from_env, seed
from torchtnt.utils.loggers import TensorBoardLogger

logging.basicConfig(level=logging.INFO)

# 1. Define your model
class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

# 2. Define your training unit
# TrainUnit is the base class for a hand-written train_step; AutoUnit
# instead asks you to implement compute_loss and manages the optimizer
# step for you.
class MyTrainingUnit(TrainUnit[tuple[torch.Tensor, torch.Tensor]]):
    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
                 logger: TensorBoardLogger, device: torch.device):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = nn.MSELoss()
        self.logger = logger
        self.device = device

    def train_step(self, state: State, data: tuple[torch.Tensor, torch.Tensor]) -> None:
        inputs, targets = data
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        outputs = self.model(inputs)
        loss = self.loss_fn(outputs, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.logger.log("train_loss", loss.item(), step=self.train_progress.num_steps_completed)

# 3. Prepare data
class RandomDataset(Dataset):
    def __init__(self, num_samples, input_dim, output_dim):
        self.data = torch.randn(num_samples, input_dim)
        self.labels = torch.randn(num_samples, output_dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

input_dim = 10
output_dim = 1
num_samples = 1000
batch_size = 32
num_epochs = 2

seed(0)
device = init_from_env()  # selects cuda when available, otherwise cpu

dataset = RandomDataset(num_samples, input_dim, output_dim)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 4. Initialize model, optimizer, and logger
model = SimpleModel(input_dim, output_dim).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
log_dir = os.path.join(os.environ.get("TORCHTNT_LOG_DIR", "./runs"), "my_experiment")
os.makedirs(log_dir, exist_ok=True)
logger = TensorBoardLogger(log_dir)

# 5. Create the training unit and run the train loop
training_unit = MyTrainingUnit(model, optimizer, logger, device)
print(f"Starting training for {num_epochs} epochs...")
train(training_unit, dataloader, max_epochs=num_epochs)
logger.close()
print(f"Training complete! Check logs in '{log_dir}'.")