{"library":"torchmetrics","title":"TorchMetrics","description":"TorchMetrics is a comprehensive collection of PyTorch-native metrics for evaluating machine learning models, offering over 100 common and specialized metrics implemented directly in PyTorch. Developed and maintained by Lightning AI, it provides a standardized, rigorously tested, and distributed-training-compatible API for metric computation, reducing boilerplate and ensuring reproducibility. It automatically accumulates metric states over batches and synchronizes them across multiple devices. The library is currently at version 1.9.0 and maintains a regular release cadence with several patch and minor releases per year.","status":"active","version":"1.9.0","language":"en","source_language":"en","source_url":"https://github.com/Lightning-AI/torchmetrics","tags":["machine learning","pytorch","metrics","deep learning","evaluation"],"install":[{"cmd":"pip install torchmetrics","lang":"bash","label":"PyPI"}],"dependencies":[{"reason":"Core dependency for all metric computations.","package":"torch","optional":false},{"reason":"Requires Python 3.10 or newer.","package":"python","optional":false}],"imports":[{"symbol":"Accuracy","correct":"from torchmetrics import Accuracy"},{"symbol":"functional.accuracy","correct":"from torchmetrics.functional import accuracy"},{"symbol":"MetricCollection","correct":"from torchmetrics import MetricCollection"},{"note":"Base class for implementing custom metrics.","symbol":"Metric","correct":"from torchmetrics import Metric"}],"quickstart":{"code":"import torch\nimport torchmetrics\nfrom torchmetrics import Accuracy, MetricCollection\nfrom torchmetrics.functional import accuracy\n\n# 1. Functional API: for single-batch, stateless computation\npreds_f = torch.randn(10, 5).softmax(dim=-1)\ntarget_f = torch.randint(5, (10,))\nacc_functional = accuracy(preds_f, target_f, task=\"multiclass\", num_classes=5)\nprint(f\"Functional Accuracy: {acc_functional.item()}\")\n\n# 2. 
Class-based API: For accumulating metrics over multiple batches/epochs\nmetric = Accuracy(task=\"multiclass\", num_classes=5)\npreds_c = torch.randn(10, 5).softmax(dim=-1)\ntarget_c = torch.randint(5, (10,))\nmetric.update(preds_c, target_c)\n\n# Simulate another batch\npreds_c2 = torch.randn(10, 5).softmax(dim=-1)\ntarget_c2 = torch.randint(5, (10,))\nmetric.update(preds_c2, target_c2)\n\nfinal_acc = metric.compute()\nprint(f\"Class-based Accuracy (accumulated): {final_acc.item()}\")\nmetric.reset() # Reset metric states for the next epoch/evaluation\n\n# 3. MetricCollection: Group multiple metrics\nmetrics = MetricCollection({\n    'Accuracy': Accuracy(task=\"multiclass\", num_classes=5),\n    'F1Score': torchmetrics.F1Score(task=\"multiclass\", num_classes=5)\n})\npreds_mc = torch.randn(10, 5).softmax(dim=-1)\ntarget_mc = torch.randint(5, (10,))\nmetrics.update(preds_mc, target_mc)\nresult_mc = metrics.compute()\nprint(f\"MetricCollection Result: {result_mc}\")","lang":"python","description":"This quickstart demonstrates the core ways to use TorchMetrics: the functional API for stateless, single-batch computation, the class-based API for accumulating states over multiple batches, and MetricCollection for grouping several metrics. Remember to reset class-based metrics after each epoch or evaluation phase to avoid mixing states."},"warnings":[{"fix":"Upgrade your Python environment to 3.10 or newer, or pin `torchmetrics<1.9.0`.","message":"Python 3.9 support has been dropped with the release of v1.9.0. The minimum required Python version is now 3.10.","severity":"breaking","affected_versions":">=1.9.0"},{"fix":"Explicitly set the `average` argument in `DiceScore` to `None` or your desired reduction method if you relied on the previous default behavior.","message":"The default value for the `average` argument in `DiceScore` has changed from `None` to `\"macro\"` starting from v1.9.0. 
This can alter the behavior of existing code if the `average` argument was not explicitly set.","severity":"breaking","affected_versions":">=1.9.0"},{"fix":"Always initialize separate metric instances for different phases (training, validation, test) or call `metric.reset()` after each complete evaluation epoch/phase to clear its internal state.","message":"Metrics maintain internal states that accumulate data. Mixing these states across different phases (e.g., training, validation, testing) or re-using the same metric instance without resetting can lead to incorrect results or memory leaks.","severity":"gotcha","affected_versions":"All"},{"fix":"Call `metric.to(device)` after initialization, or ensure the metric is registered as a child module within a `torch.nn.Module` or `LightningModule`, which handles device transfers automatically.","message":"Metric states are initialized on the CPU. When working with PyTorch tensors on GPU, especially in distributed training (DDP), ensure that metric objects are moved to the same device as the input data using `.to(device)`. Failure to do so can result in `RuntimeError: Encountered different devices in metric calculation`.","severity":"gotcha","affected_versions":"All"},{"fix":"Use `torch.nn.ModuleList` or `torch.nn.ModuleDict` instead of native Python collections when nesting metrics within a `torch.nn.Module`.","message":"When defining metrics as part of a `torch.nn.Module` or `LightningModule`, avoid using native Python `list` or `dict` to store multiple `Metric` instances. These will not be correctly identified as child modules, preventing automatic device placement and state management.","severity":"gotcha","affected_versions":"All"},{"fix":"Ensure all metrics within the `MetricCollection` receive `update` calls for the relevant data. 
If issues persist, consider isolating metrics or upgrading to the latest `torchmetrics` version, as device management and state handling improve with each release.","message":"Users of `MetricCollection` might encounter `UserWarning: The compute method of metric X was called before the update method...` This often indicates that the internal states of grouped metrics are not being updated before `compute` is called, particularly in older versions or with specific usage patterns.","severity":"gotcha","affected_versions":"<=1.8.x"},{"fix":"Profile your code. For `MeanMetric`, consider `metric.update(value, weight=my_tensor)`. For advanced optimization, explore `Aggregator` configurations and carefully manage device placement and cross-device synchronization.","message":"For performance-critical applications, especially with `MeanMetric`, explicitly providing a weight tensor to `update` instead of relying on the default can be beneficial. Additionally, disabling NaN checks in the base `Aggregator` class or careful device management can reduce overhead.","severity":"gotcha","affected_versions":"All"}],"env_vars":null,"last_verified":"2026-04-06T00:00:00.000Z","next_check":"2026-07-05T00:00:00.000Z"}