Torch-Fidelity: Generative Model Metrics

0.4.0 · active · verified Wed Apr 15

Torch-fidelity is a PyTorch library that provides precise, efficient, and extensible implementations of popular generative model evaluation metrics: Inception Score (ISC), Fréchet Inception Distance (FID), Kernel Inception Distance (KID), Perceptual Path Length (PPL), and Precision and Recall (PRC). It aims for epsilon-exact numerical agreement with the reference TensorFlow implementations. The library is actively maintained; the latest version is 0.4.0, and releases have added new metrics and feature extractors.
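
To make one of the listed metrics concrete, the following is a minimal NumPy sketch of the Fréchet distance between two Gaussians, the quantity FID evaluates on feature statistics (mean and covariance) of two image sets. It is an illustration only, not torch-fidelity's implementation; the function name `frechet_distance` is hypothetical, and the feature extraction step that would produce the statistics is omitted.

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Computes ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2)).
    """
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)

    diff = mu1 - mu2
    # The trace of the matrix square root of sigma1 @ sigma2 equals the sum of
    # the square roots of its eigenvalues, which are real and non-negative
    # when both covariances are positive semi-definite.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)

# Toy statistics for illustration: identical covariances, shifted means.
print(frechet_distance([0, 0, 0], np.eye(3), [1, 2, 2], np.eye(3)))
```

With identical covariances the trace term vanishes, so the result reduces to the squared distance between the means. Production implementations typically add numerical safeguards that this sketch leaves out.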

Quickstart

This quickstart shows how to calculate Inception Score (ISC), Fréchet Inception Distance (FID), and Kernel Inception Distance (KID) with `torch-fidelity`'s Python API. It defines a dummy generative model, wraps it with `GenerativeModelModuleWrapper`, and passes it to the `calculate_metrics` function together with a reference input (here, the pre-registered dataset 'cifar10-train'). ISC is computed from the generated samples alone, while FID and KID compare them against the reference input. The results are returned as a dictionary.

import torch
import torch.nn as nn
from torch_fidelity import calculate_metrics, GenerativeModelModuleWrapper

# Dummy generator model for demonstration
class DummyGenerator(nn.Module):
    def __init__(self, z_size, img_size, img_channels):
        super().__init__()
        self.img_size = img_size
        self.img_channels = img_channels
        self.main = nn.Sequential(
            nn.Linear(z_size, 256),
            nn.ReLU(),
            nn.Linear(256, img_channels * img_size * img_size),
            nn.Sigmoid()
        )

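    def forward(self, z):
        img = self.main(z)  # floats in [0, 1] from the final Sigmoid
        img = img.view(-1, self.img_channels, self.img_size, self.img_size)
        # torch-fidelity's feature extractors expect uint8 images in [0, 255]
        return (img * 255).to(torch.uint8)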

# Configuration
z_size = 128
img_size = 32
img_channels = 3
num_samples = 10000 # Number of samples to generate for metrics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate and wrap the generator (num_classes=0 means unconditional)
generator = DummyGenerator(z_size, img_size, img_channels).to(device)
wrapped_generator = GenerativeModelModuleWrapper(generator, z_size, 'normal', 0)

# Calculate metrics
# ISC needs only input1; FID and KID also need a second input, e.g., a
# registered dataset name or a directory path.
# Here, we use the registered input 'cifar10-train' for demonstration.
metrics_dict = calculate_metrics(
    input1=wrapped_generator,
    input1_model_num_samples=num_samples,  # number of images to sample from the generator
    input2='cifar10-train',
    batch_size=32,
    cuda=device.type == 'cuda',
    isc=True,
    fid=True,
    kid=True,
    verbose=False,
    save_cpu_ram=True  # Optional: use less CPU RAM at the cost of speed
)

print(metrics_dict)
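
The returned dictionary can be read by key. The sketch below formats the ISC, FID, and KID entries; the key names are the ones torch-fidelity is understood to use for these metrics, so check them against the dictionary your installed version prints. The helper `summarize_metrics` is hypothetical, and the numbers in `example` are placeholders, not measured results.

```python
# Key names assumed for ISC, FID, and KID results; verify against your output.
EXPECTED_KEYS = (
    "inception_score_mean",
    "inception_score_std",
    "frechet_inception_distance",
    "kernel_inception_distance_mean",
    "kernel_inception_distance_std",
)

def summarize_metrics(metrics):
    """Return one 'name: value' line per known metric present in `metrics`."""
    lines = []
    for key in EXPECTED_KEYS:
        if key in metrics:
            lines.append(f"{key}: {float(metrics[key]):.4f}")
    return lines

# Placeholder values for illustration only.
example = {
    "inception_score_mean": 1.0,
    "inception_score_std": 0.0,
    "frechet_inception_distance": 400.0,
}
for line in summarize_metrics(example):
    print(line)
```

Metrics that were not requested (for example KID when `kid=False`) are simply skipped, so the same helper works for any subset of the three.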
