TorchEval
TorchEval is a PyTorch library that provides a simple interface for creating new metrics and an easy-to-use toolkit for metric computation and checkpointing. It ships a rich collection of high-performance metric implementations out of the box, leveraging PyTorch's vectorization and GPU acceleration. Currently at version 0.0.7, it maintains an active release schedule, with new metrics added regularly.
Common errors
- TypeError: 'BinaryAccuracy' object is not callable
  cause: Calling the metric instance directly instead of using its `update()` or `compute()` methods.
  fix: Use `metric.update(predictions, targets)` to feed data and `metric.compute()` to get the result. The metric instance itself is not a function.
- RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
  cause: The prediction and target tensors passed to `metric.update()` are on different devices (e.g., one on CPU, one on GPU), or on a different device than the metric's state.
  fix: Move the tensors (and the metric, via `metric.to(device)`) to the same device before calling `update()`, e.g. `metric.update(predictions.to('cuda'), targets.to('cuda'))`.
- ValueError: The 'input' tensor must be 1D for BinaryAccuracy
  cause: The prediction tensor passed to `BinaryAccuracy.update()` has the wrong shape (e.g., `(N, C)` from a multiclass head, or `(N, 1, H, W)`).
  fix: Reshape the input to a 1-dimensional tensor of shape `(N,)`, e.g. with `.squeeze()`, `.view(-1)`, or `.flatten()`.
- Incorrect or unexpected metric values in distributed training
  cause: Computing the final result without synchronizing metric states across distributed processes.
  fix: In a distributed environment, call `sync_and_compute(metric)` from `torcheval.metrics.toolkit` instead of `metric.compute()` so that all local states are gathered and aggregated before calculation.
Warnings
- gotcha In distributed training environments (e.g., using `torch.distributed`), metrics accumulate local states independently on each process. To get the correct global metric value, use `sync_and_compute(metric)` from `torcheval.metrics.toolkit` instead of `metric.compute()`.
- gotcha Metrics accumulate their internal state across multiple calls to `update()`. If you need to calculate metrics for distinct evaluation periods (e.g., per epoch or per validation run), you must call `metric.reset()` before processing new data, or create a new metric instance.
- gotcha TorchEval is currently in a pre-1.0 state (0.0.x versions). While efforts are made to maintain stability, minor API changes might occur between releases. Always refer to the latest documentation for precise API details.
- gotcha Input tensors for `update()` must adhere to the specific shapes and dtypes expected by each metric. For instance, binary metrics typically expect 1D tensors of shape (N,) for both predictions and targets.
Install
- pip
pip install torcheval
Imports
- Metric
from torcheval.metrics import Metric
- BinaryAccuracy
from torcheval.metrics.classification import BinaryAccuracy
from torcheval.metrics import BinaryAccuracy
- sync_and_compute (distributed aggregation)
from torcheval.metrics.toolkit import sync_and_compute
Quickstart
import torch
from torcheval.metrics import BinaryAccuracy
# Initialize the metric
metric = BinaryAccuracy()
# Simulate model predictions and ground truth labels
# Ensure inputs are tensors and on the correct device
predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.95])
targets = torch.tensor([1, 0, 1, 0, 1])
# Update the metric with a batch of data
metric.update(predictions, targets)
# Get the computed result
accuracy = metric.compute()
print(f"Binary Accuracy: {accuracy.item():.4f}")
# Example with another batch
predictions2 = torch.tensor([0.4, 0.6, 0.7])
targets2 = torch.tensor([0, 1, 0])
metric.update(predictions2, targets2)
# Compute cumulative accuracy
cumulative_accuracy = metric.compute()
print(f"Cumulative Binary Accuracy: {cumulative_accuracy.item():.4f}")
# Reset the metric's internal state before the next evaluation run
metric.reset()
# After a reset the metric holds no data; call update() again before compute()