TorchMetrics
TorchMetrics is a comprehensive collection of PyTorch-native metrics for evaluating machine learning models, offering over 100 common and specialized metrics implemented directly in PyTorch. Developed and maintained by Lightning AI, it provides a standardized, rigorously tested, distributed-training-compatible API for metric computation, reducing boilerplate and ensuring reproducibility. Metrics automatically accumulate results over batches and synchronize state across multiple devices. The library is currently at version 1.9.0 and maintains a regular release cadence, with several patch and minor releases per year.
Warnings
- breaking Python 3.9 support has been dropped with the release of v1.9.0; the minimum required Python version is now 3.10.
- breaking The default value for the `average` argument in `DiceScore` has changed from `None` to `"macro"` starting from v1.9.0. This can alter the behavior of existing code if the `average` argument was not explicitly set.
- gotcha Metrics maintain internal states that accumulate data. Mixing these states across different phases (e.g., training, validation, testing), or reusing the same metric instance without calling `reset()` between epochs, can lead to incorrect results or unbounded memory growth.
- gotcha Metric states are initialized on the CPU. When working with PyTorch tensors on GPU, especially in distributed training (DDP), ensure that metric objects are moved to the same device as the input data using `.to(device)`. Failure to do so can result in `RuntimeError: Encountered different devices in metric calculation`.
- gotcha When defining metrics as part of a `torch.nn.Module` or `LightningModule`, avoid using native Python `list` or `dict` to store multiple `Metric` instances. These will not be correctly identified as child modules, preventing automatic device placement and state management. Use `torch.nn.ModuleList`, `torch.nn.ModuleDict`, or `MetricCollection` instead.
- gotcha Users of `MetricCollection` might encounter `UserWarning: The compute method of metric X was called before the update method...`. This usually means the internal states of grouped metrics were not updated before `compute` was called, particularly in older versions or with specific usage patterns.
- gotcha For performance-critical applications, especially with `MeanMetric`, explicitly passing a weight tensor to `update` instead of relying on the default can be beneficial. Disabling NaN checks in the base `Aggregator` class and careful device management can also reduce overhead.
Install
- pip install torchmetrics
Imports
- Accuracy
from torchmetrics import Accuracy
- functional.accuracy
from torchmetrics.functional import accuracy
- MetricCollection
from torchmetrics import MetricCollection
- Metric
from torchmetrics import Metric
Quickstart
import torch
import torchmetrics
from torchmetrics import Accuracy, MetricCollection
from torchmetrics.functional import accuracy
# 1. Functional API: For single-batch, stateless computation
preds_f = torch.randn(10, 5).softmax(dim=-1)
target_f = torch.randint(5, (10,))
acc_functional = accuracy(preds_f, target_f, task="multiclass", num_classes=5)
print(f"Functional Accuracy: {acc_functional.item()}")
# 2. Class-based API: For accumulating metrics over multiple batches/epochs
metric = Accuracy(task="multiclass", num_classes=5)
preds_c = torch.randn(10, 5).softmax(dim=-1)
target_c = torch.randint(5, (10,))
metric.update(preds_c, target_c)
# Simulate another batch
preds_c2 = torch.randn(10, 5).softmax(dim=-1)
target_c2 = torch.randint(5, (10,))
metric.update(preds_c2, target_c2)
final_acc = metric.compute()
print(f"Class-based Accuracy (accumulated): {final_acc.item()}")
metric.reset() # Reset metric states for the next epoch/evaluation
# 3. MetricCollection: Group multiple metrics
metrics = MetricCollection({
    'Accuracy': Accuracy(task="multiclass", num_classes=5),
    'F1Score': torchmetrics.F1Score(task="multiclass", num_classes=5)
})
preds_mc = torch.randn(10, 5).softmax(dim=-1)
target_mc = torch.randint(5, (10,))
metrics.update(preds_mc, target_mc)
result_mc = metrics.compute()
print(f"MetricCollection Result: {result_mc}")