LPIPS (Learned Perceptual Image Patch Similarity)
LPIPS is a Python library that implements the Learned Perceptual Image Patch Similarity metric. The metric measures the similarity between two images in a way that aligns more closely with human perception than traditional metrics such as MSE or SSIM, by computing distances between deep features extracted from pre-trained convolutional neural networks (AlexNet, VGG, or SqueezeNet). The library is currently at version 0.1.4 and is maintained primarily through its GitHub repository, with releases tied to significant updates. It is widely used in image generation and restoration tasks, both to evaluate perceptual quality and as a perceptual loss function during training.
Warnings
- breaking A bug in the initial `v0.0` release (before `v0.1.x`) caused inputs not to be scaled, producing results that differ from the paper. This was fixed in `v0.1` and later versions, where linear scaling is enabled by default.
- gotcha Input images for LPIPS must be 3-channel RGB PyTorch Tensors and normalized to the range `[-1, 1]`. Incorrect normalization or channel dimensions (e.g., `[0, 1]` range or grayscale) will lead to incorrect or unexpected similarity scores.
- gotcha The default network `net='alex'` is optimized for best *forward* scores (evaluating similarity). For use as a 'perceptual loss' in optimization/backpropagation, `net='vgg'` is often recommended as it is closer to traditional perceptual loss functions.
- gotcha LPIPS models, being based on deep neural networks, can be susceptible to adversarial attacks: small, imperceptible perturbations can significantly alter the LPIPS score, so that perceptually similar images are judged very different by the metric. Robust variants such as E-LPIPS and R-LPIPS address this, but they are not part of the core `lpips` package.
- gotcha Running LPIPS, especially with the 'vgg' backbone, can consume significant GPU memory. This can lead to out-of-memory errors with larger batch sizes or higher resolution images.
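The `[-1, 1]` requirement above is easy to get wrong when images arrive as `uint8` arrays or `[0, 1]` float tensors. A minimal rescaling sketch (the helper name `to_lpips_range` is illustrative, not part of the `lpips` API):

```python
import torch

def to_lpips_range(img: torch.Tensor) -> torch.Tensor:
    """Rescale an RGB image tensor to the [-1, 1] range LPIPS expects.

    Accepts float tensors in [0, 1] or uint8 tensors in [0, 255].
    """
    if img.dtype == torch.uint8:
        img = img.float() / 255.0
    return img * 2 - 1

# A [0, 1] float image becomes [-1, 1]
x = torch.rand(1, 3, 64, 64)
y = to_lpips_range(x)
```

Recent versions of `lpips` also accept `normalize=True` in the forward call to rescale `[0, 1]` inputs internally, but rescaling explicitly keeps the assumption visible in your code.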
Install
pip install lpips
Imports
- LPIPS
import lpips
loss_fn = lpips.LPIPS()
Quickstart
import torch
import lpips
# Ensure PyTorch is set up (e.g., for GPU if available)
# LPIPS downloads pre-trained backbone weights on first use; no API keys are required
# Initialize the LPIPS model, using AlexNet as the default backbone
# 'alex' is recommended for best forward scores, 'vgg' for perceptual loss in optimization
loss_fn_alex = lpips.LPIPS(net='alex')
# Create two dummy images (batch_size, channels, height, width)
# IMPORTANT: Images should be RGB (3 channels) and normalized to [-1, 1]
img0 = torch.rand(1, 3, 64, 64) * 2 - 1 # Random image 1, normalized to [-1, 1]
img1 = torch.rand(1, 3, 64, 64) * 2 - 1 # Random image 2, normalized to [-1, 1]
# Compute the LPIPS distance
d = loss_fn_alex(img0, img1)
print(f"LPIPS distance: {d.item():.4f}")
# Example with VGG network, often preferred for 'perceptual loss' in training
loss_fn_vgg = lpips.LPIPS(net='vgg')
d_vgg = loss_fn_vgg(img0, img1)
print(f"LPIPS distance (VGG): {d_vgg.item():.4f}")