PyTorch Image Quality (PIQ)
PIQ (PyTorch Image Quality) is a collection of measures and metrics for automatic image quality assessment in image-to-image tasks such as denoising, super-resolution, and image generation. Currently at version 0.8.0, the library offers both functional interfaces for calculating metrics and PyTorch modules for using them as loss functions. It has a regular release cadence, with minor versions released every 1-2 months, continually extending its set of measures and metrics.
Common errors
-
RuntimeError: Expected input to be non-negative, but got values outside this range.
cause Some PIQ metrics and loss functions validate by default that input tensor values are non-negative (e.g., expecting image pixels in [0, 1] or [0, 255]). The error occurs when your network outputs values outside this range because no appropriate activation (like `sigmoid`) is applied, or when `data_range` is set incorrectly.
fix Ensure input tensors are normalized to the `data_range` expected by the metric (e.g., `data_range=1.0` for `[0,1]`, `data_range=255.0` for `[0,255]`). If your model produces values outside the expected non-negative range and the metric does not support negative values via a flag, apply an activation function (like `torch.sigmoid`) to its output. Some metrics have an `allow_negative=True` flag to permit negative inputs.
-
Input tensor shape mismatch. Expected NCHW, got NHWC.
cause PIQ metrics typically expect image tensors in `NCHW` format (Batch, Channels, Height, Width). A common mistake is providing `NHWC` (Batch, Height, Width, Channels) or incorrect spatial dimensions.
fix Verify the shape of your input tensors. If your images are `NHWC`, use `tensor.permute(0, 3, 1, 2)` to convert them to `NCHW` before passing them to PIQ functions. Also check the documentation of the specific metric for any minimum height/width requirements (e.g., `multi_scale_gmsd`).
-
AttributeError: module 'PhotoSynthesis.Metrics' has no attribute 'ssim'
cause You are attempting to import from `PhotoSynthesis.Metrics`, which was the old package name. The library was renamed to `piq` in version 0.4.1.
fix Update your import statements to use the new package name: `from piq import ssim` (or other desired metrics).
-
RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 2
cause This generic PyTorch error indicates that the dimensions of the two input tensors (e.g., `prediction` and `target` images) do not match. Full-reference image quality metrics compare the two images element by element, so they require matching shapes.
fix Ensure that the `prediction` and `target` tensors passed to PIQ metrics have identical shapes across all dimensions (batch size, channels, height, width). Resample or crop images if necessary before comparison.
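The fixes above (layout conversion, range normalization, and size matching) can be combined in a small preprocessing step. The following is a minimal sketch using only PyTorch; the helper names `to_nchw` and `match_spatial_size` are hypothetical and not part of PIQ, and bilinear resampling is just one possible choice.

```python
import torch
import torch.nn.functional as F


def to_nchw(images: torch.Tensor) -> torch.Tensor:
    """Convert a batch from NHWC to NCHW layout (hypothetical helper)."""
    return images.permute(0, 3, 1, 2).contiguous()


def match_spatial_size(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Resize an NCHW `prediction` to the height and width of `target` (hypothetical helper)."""
    if prediction.shape[-2:] == target.shape[-2:]:
        return prediction
    return F.interpolate(
        prediction, size=target.shape[-2:], mode="bilinear", align_corners=False
    )


# Assumed raw network output: NHWC layout, unbounded values, smaller resolution.
raw_output = torch.randn(4, 128, 128, 3)
target = torch.rand(4, 3, 256, 256)  # NCHW, values in [0, 1]

prediction = torch.sigmoid(to_nchw(raw_output))      # squash values into (0, 1)
prediction = match_spatial_size(prediction, target)  # resample to 256 x 256

print(prediction.shape, float(prediction.min()), float(prediction.max()))
```

After this step `prediction` and `target` share the same `NCHW` shape and value range, so they should be suitable inputs for a call such as `ssim(prediction, target, data_range=1.)`.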
Warnings
- breaking The library underwent a significant rename from `PhotoSynthesis.Metrics` to `piq` in version 0.4.1. Code using old import paths (e.g., `from PhotoSynthesis.Metrics import ssim`) will fail.
- gotcha Backpropagation for the `brisque` metric is not available when using `torch==1.5.0` due to a known bug in PyTorch's `argmin` and `argmax` operations.
- gotcha Many PIQ metrics expect input tensors to have specific value ranges (e.g., non-negative for image pixels) or shapes. Default input validation includes checks like `assert torch.all(tensor >= 0)`, which can raise errors if inputs are outside expected bounds, particularly when using metrics as loss functions without a final activation. Some metrics (since v0.5.4) offer an `allow_negative=True` flag.
- gotcha For performance-critical applications, PIQ's extensive input validation (assertions) can introduce overhead. These checks can be disabled globally.
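To illustrate what assert-based validation looks like and why it can be switched off, here is a small sketch. The function `validate_image_batch` is a hypothetical stand-in, not PIQ's actual validation code. It relies on a general Python fact: `assert` statements are skipped when the interpreter runs with the `-O` flag. Whether that is the mechanism PIQ intends for disabling its checks is an assumption here.

```python
import torch


def validate_image_batch(x: torch.Tensor, data_range: float = 1.0) -> None:
    """Hypothetical assert-based checks, similar in spirit to those described above."""
    assert x.dim() == 4, f"Expected a 4D NCHW tensor, got {x.dim()}D"
    assert torch.all(x >= 0), "Expected input to be non-negative"
    assert torch.all(x <= data_range), f"Expected values <= {data_range}"


good = torch.rand(2, 3, 8, 8)    # values in [0, 1)
bad = -torch.ones(2, 3, 8, 8)    # negative values, should be rejected

validate_image_batch(good)  # passes silently

try:
    validate_image_batch(bad)
    caught = False
except AssertionError as err:
    caught = True
    print(f"Rejected: {err}")

# __debug__ is True in a normal run and False under `python -O`,
# in which case the asserts are compiled away and nothing is caught.
print(f"assertions active: {__debug__}, bad input caught: {caught}")
```

Each check scans the whole tensor, which is where the overhead comes from when a metric is called on every training step.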
Install
-
pip install piq
-
conda install piq -c photosynthesis-team -c conda-forge -c pytorch
Imports
- ssim
from piq import ssim
- SSIMLoss
from piq import SSIMLoss
- brisque
from piq import brisque
- CLIPIQA
from piq import CLIPIQA
- PhotoSynthesis.Metrics (old package name; this import fails since the rename to `piq` in version 0.4.1)
from PhotoSynthesis.Metrics import ssim
Quickstart
import torch
from piq import ssim, SSIMLoss
# Example tensors (batch_size, channels, height, width)
x = torch.rand(4, 3, 256, 256, requires_grad=True)
y = torch.rand(4, 3, 256, 256)
# 1. Functional interface: Compute SSIM as a measure
ssim_index = ssim(x, y, data_range=1.)
print(f"SSIM index: {ssim_index.item():0.4f}")
# 2. Class interface: Use SSIM as a loss function
loss_fn = SSIMLoss(data_range=1.)
output_loss = loss_fn(x, y)
output_loss.backward() # Backpropagate the loss
print(f"SSIM Loss: {output_loss.item():0.4f}")