PyTorch Metric Learning

2.9.0 · active · verified Thu Apr 09

PyTorch Metric Learning is a Python library (version 2.9.0) that simplifies the use of deep metric learning in applications. It offers a modular, flexible, and extensible framework built on PyTorch, providing a wide array of loss functions, miners, samplers, trainers, and testers. The library maintains an active release cadence, with frequent updates introducing new features and improvements.

Warnings

Install

Imports

Quickstart

This quickstart demonstrates a basic training loop using PyTorch Metric Learning. It defines a dummy dataset and model, then initializes a `TripletMarginLoss` and `MultiSimilarityMiner`. The training loop moves data and model to the appropriate device, generates embeddings, mines for hard triplets, computes the loss, and performs backpropagation.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from pytorch_metric_learning import losses, miners

# 1. Dummy Dataset for demonstration
class DummyDataset(Dataset):
    """Synthetic dataset: random Gaussian feature vectors paired with random integer labels."""

    def __init__(self, num_samples=100, embedding_dim=64, num_classes=10):
        # Materialize a fixed pool of random "embeddings" and class labels up front.
        self.data = torch.randn(num_samples, embedding_dim)
        self.labels = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        # One label per sample, so the label tensor's length is the dataset size.
        return self.labels.shape[0]

    def __getitem__(self, idx):
        # Return the (feature, label) pair at the requested index.
        sample, label = self.data[idx], self.labels[idx]
        return sample, label

# 2. Dummy Model (e.g., identity for pre-computed embeddings)
class DummyModel(nn.Module):
    """Minimal embedding network: a single square linear projection."""

    def __init__(self, embedding_dim):
        super().__init__()
        # One affine map (dim -> dim) keeps the example trivially small.
        self.linear = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # Project the input through the linear layer and return the result.
        projected = self.linear(x)
        return projected

# ---- Hyperparameters -------------------------------------------------------
EMBEDDING_DIM = 64
NUM_CLASSES = 10
BATCH_SIZE = 32
NUM_EPOCHS = 2

# Prefer a GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Data ------------------------------------------------------------------
dataset = DummyDataset(embedding_dim=EMBEDDING_DIM, num_classes=NUM_CLASSES)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ---- Model, loss, miner, optimizer -----------------------------------------
model = DummyModel(EMBEDDING_DIM).to(device)
loss_func = losses.TripletMarginLoss(margin=0.1).to(device)
miner = miners.MultiSimilarityMiner(epsilon=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# ---- Training loop ---------------------------------------------------------
print(f"Training on {device}...")
for epoch in range(NUM_EPOCHS):
    for step, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        embeddings = model(inputs)

        # Select informative (hard) triplets from the current batch.
        hard_triplets = miner(embeddings, targets)

        # The loss is computed only over the mined triplets.
        loss = loss_func(embeddings, targets, hard_triplets)

        loss.backward()
        optimizer.step()

        # Periodic progress report (every 10th batch).
        if step % 10 == 0:
            print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Batch {step+1}/{len(dataloader)}, Loss: {loss.item():.4f}")

print("Training complete.")

view raw JSON →