PyTorch Metric Learning
PyTorch Metric Learning is a Python library (version 2.9.0) that simplifies the use of deep metric learning in applications. It offers a modular, flexible, and extensible framework built on PyTorch, providing a wide array of loss functions, miners, samplers, trainers, and testers. The library maintains an active release cadence, with frequent updates introducing new features and improvements.
Warnings
- breaking The `emb` argument of `DistributedLossWrapper.forward` was renamed to `embeddings` for consistency with the rest of the library.
- breaking The default value of the `symmetric` flag in `SelfSupervisedLoss` changed from `False` to `True`. If `False`, only `embeddings` are used as anchors. If `True`, `embeddings` and `ref_emb` are both used as anchors.
- gotcha When using PyTorch's `DistributedDataParallel`, `DistributedLossWrapper` and `DistributedMinerWrapper` are essential. Without them, the loss and miner in each process see only that process's fraction of the global batch, leading to incorrect calculations.
- gotcha Very large batch sizes can lead to `INT_MAX` errors within `loss_and_miner_utils` due to an extremely high number of pairs/triplets being processed.
- gotcha Device mismatches (CPU/GPU) are a common PyTorch error. Ensure your model, input data, and loss/miner functions are all on the same device (e.g., 'cuda') to avoid `RuntimeError: Expected all tensors to be on the same device...`.
Install
- pip install pytorch-metric-learning
- pip install pytorch-metric-learning[with-hooks]
Imports
- TripletMarginLoss
from pytorch_metric_learning.losses import TripletMarginLoss
- MultiSimilarityMiner
from pytorch_metric_learning.miners import MultiSimilarityMiner
- MetricLossOnly
from pytorch_metric_learning.trainers import MetricLossOnly
- AccuracyCalculator
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
- DistributedLossWrapper
from pytorch_metric_learning.wrappers import DistributedLossWrapper
Quickstart
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pytorch_metric_learning import losses, miners

# 1. Dummy dataset for demonstration
class DummyDataset(Dataset):
    def __init__(self, num_samples=100, embedding_dim=64, num_classes=10):
        self.data = torch.randn(num_samples, embedding_dim)
        self.labels = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 2. Dummy model (e.g., a simple projection for pre-computed embeddings)
class DummyModel(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, embedding_dim)  # simple linear layer

    def forward(self, x):
        return self.linear(x)

# Configuration
embedding_dim = 64
num_classes = 10
batch_size = 32
num_epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize dataset and dataloader
dataset = DummyDataset(embedding_dim=embedding_dim, num_classes=num_classes)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss, miner, and optimizer
model = DummyModel(embedding_dim).to(device)
loss_func = losses.TripletMarginLoss(margin=0.1).to(device)
miner = miners.MultiSimilarityMiner(epsilon=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
print(f"Training on {device}...")
for epoch in range(num_epochs):
    for i, (data, labels) in enumerate(dataloader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        # Mine hard pairs within the batch (MultiSimilarityMiner returns pair indices)
        hard_pairs = miner(embeddings, labels)
        # Compute loss using the mined indices
        loss = loss_func(embeddings, labels, hard_pairs)
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {i+1}/{len(dataloader)}, Loss: {loss.item():.4f}")
print("Training complete.")