TorchMetrics

1.9.0 · active · verified Mon Apr 06

TorchMetrics is a collection of over 100 PyTorch-native metrics for evaluating machine learning models, covering both common and specialized use cases, all implemented directly in PyTorch. Developed and maintained by Lightning AI, it provides a standardized, rigorously tested API that is compatible with distributed training, which reduces boilerplate and helps keep metric computation reproducible. Class-based metrics automatically accumulate state over batches and synchronize that state across devices and processes. The library is currently at version 1.9.0 and follows a regular release cadence, with several minor and patch releases per year.

Quickstart

This quickstart demonstrates the three core ways to use TorchMetrics: the functional API for stateless, single-batch computation; the class-based API, which accumulates state over multiple batches; and MetricCollection, which groups several metrics behind a single update/compute call. Reset class-based metrics after each epoch or evaluation phase so that state accumulated in one phase does not leak into the next.

import torch
import torchmetrics
from torchmetrics import Accuracy, MetricCollection
from torchmetrics.functional import accuracy

# 1. Functional API: For single-batch, stateless computation
preds_f = torch.randn(10, 5).softmax(dim=-1)
target_f = torch.randint(5, (10,))
acc_functional = accuracy(preds_f, target_f, task="multiclass", num_classes=5)
print(f"Functional Accuracy: {acc_functional.item()}")

# 2. Class-based API: For accumulating metrics over multiple batches/epochs
metric = Accuracy(task="multiclass", num_classes=5)
preds_c = torch.randn(10, 5).softmax(dim=-1)
target_c = torch.randint(5, (10,))
metric.update(preds_c, target_c)

# Simulate another batch
preds_c2 = torch.randn(10, 5).softmax(dim=-1)
target_c2 = torch.randint(5, (10,))
metric.update(preds_c2, target_c2)

final_acc = metric.compute()
print(f"Class-based Accuracy (accumulated): {final_acc.item()}")
metric.reset()  # Reset metric states for the next epoch/evaluation

# 3. MetricCollection: Group multiple metrics
metrics = MetricCollection({
    'Accuracy': Accuracy(task="multiclass", num_classes=5),
    'F1Score': torchmetrics.F1Score(task="multiclass", num_classes=5)
})
preds_mc = torch.randn(10, 5).softmax(dim=-1)
target_mc = torch.randint(5, (10,))
metrics.update(preds_mc, target_mc)
result_mc = metrics.compute()
print(f"MetricCollection Result: {result_mc}")
metrics.reset()  # Reset every metric in the collection before the next phase
