PyTorch MS-SSIM
pytorch-msssim provides a fast, differentiable implementation of the Structural Similarity (SSIM) and Multi-Scale Structural Similarity (MS-SSIM) indices for PyTorch. It gains its speed from separable Gaussian kernels, which replace one 2D convolution with two 1D passes. The library is currently at version 1.0.0; releases are made as needed rather than on a fixed schedule.
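To illustrate why separable kernels help, the sketch below blurs a batch with two 1D Gaussian passes and checks that the result matches a single 2D pass. It uses only `torch`; the helper names (`gaussian_1d`, `blur_separable`, `blur_full`) are made up for this example and are not the library's internal code. The window size of 11 and sigma of 1.5 are assumed here as typical SSIM settings.

```python
import torch
import torch.nn.functional as F


def gaussian_1d(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    """Return a normalized 1D Gaussian window of length `size`."""
    coords = torch.arange(size, dtype=torch.float32) - size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    return g / g.sum()


def blur_separable(x: torch.Tensor, win: torch.Tensor) -> torch.Tensor:
    """Blur an (N, C, H, W) tensor with two 1D passes: rows, then columns."""
    c, k = x.shape[1], win.numel()
    # One filter per channel (depthwise convolution via groups=c).
    horiz = win.view(1, 1, 1, k).repeat(c, 1, 1, 1)
    vert = win.view(1, 1, k, 1).repeat(c, 1, 1, 1)
    out = F.conv2d(x, horiz, groups=c)
    return F.conv2d(out, vert, groups=c)


def blur_full(x: torch.Tensor, win: torch.Tensor) -> torch.Tensor:
    """Blur the same tensor with one 2D kernel (outer product of the window)."""
    c, k = x.shape[1], win.numel()
    weight = torch.outer(win, win).view(1, 1, k, k).repeat(c, 1, 1, 1)
    return F.conv2d(x, weight, groups=c)


x = torch.rand(2, 3, 64, 64)
win = gaussian_1d()
sep = blur_separable(x, win)
full = blur_full(x, win)

print(sep.shape)                             # torch.Size([2, 3, 54, 54])
print(torch.allclose(sep, full, atol=1e-5))  # the two results agree
```

For a window of length k, the separable version costs roughly 2k multiplications per output pixel instead of k squared, which is where the saving comes from.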
Warnings
- gotcha Input images must be non-negative and scaled to match `data_range`. If your images are normalized (e.g., to [-1, 1]), denormalize them to [0, 1] or [0, 255] before passing them to `ssim` or `ms_ssim`, and set `data_range` to 1 or 255 accordingly.
- gotcha Setting `nonnegative_ssim=True` is recommended for SSIM to prevent negative output values. It defaults to `False` for consistency with other implementations such as TensorFlow and scikit-image. For MS-SSIM, the intermediate SSIM responses are forced to be non-negative internally to avoid NaN results.
- gotcha For 2D images, input tensors `X` and `Y` must be 4-dimensional: `(N, C, H, W)` (batch size, channels, height, width). Passing a single 3D image `(C, H, W)` raises an error; add a batch dimension first, e.g., with `unsqueeze(0)`.
- gotcha When using MS-SSIM as a training loss, especially with unstable models, setting the `normalize` parameter (e.g., `normalize='relu'`) in the `MS_SSIM` module can improve training stability and help avoid NaN results. This option is adapted from a different implementation. Check that your installed version accepts `normalize` before relying on it, since the accepted keyword arguments have changed between releases.
Install
pip install pytorch-msssim
Imports
- ssim
from pytorch_msssim import ssim
- ms_ssim
from pytorch_msssim import ms_ssim
- SSIM
from pytorch_msssim import SSIM
- MS_SSIM
from pytorch_msssim import MS_SSIM
- pytorch_ssim
Quickstart
import torch
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
# Create two dummy image tensors (batch_size, channels, height, width)
# Images are typically non-negative, e.g., 0-255 or 0-1
X = torch.rand(4, 3, 256, 256) * 255 # Example: batch of 4 RGB images, 0-255 range
Y = torch.rand(4, 3, 256, 256) * 255 # Another batch for comparison
# Calculate SSIM and MS-SSIM values (per image in batch)
# data_range should match the maximum possible pixel value (e.g., 255 for 0-255 images)
ssim_val = ssim(X, Y, data_range=255, size_average=False) # Returns (N,) tensor
ms_ssim_val = ms_ssim(X, Y, data_range=255, size_average=False) # Returns (N,) tensor
print(f"SSIM values: {ssim_val}")
print(f"MS-SSIM values: {ms_ssim_val}")
# Using SSIM/MS_SSIM as a loss function (returns scalar mean loss)
# For loss, set size_average=True and typically use 1 - score
ssim_loss_module = SSIM(data_range=255, size_average=True, channel=3)
ms_ssim_loss_module = MS_SSIM(data_range=255, size_average=True, channel=3)
ssim_loss = 1 - ssim_loss_module(X, Y) # A scalar tensor
ms_ssim_loss = 1 - ms_ssim_loss_module(X, Y) # A scalar tensor
print(f"SSIM Loss: {ssim_loss}")
print(f"MS-SSIM Loss: {ms_ssim_loss}")