PyTorch MS-SSIM

1.0.0 · active · verified Sat Apr 11

pytorch-msssim provides fast, differentiable implementations of the Structural Similarity (SSIM) and Multi-Scale Structural Similarity (MS-SSIM) indices for PyTorch. It gets its speed from separable Gaussian kernels, which replace one 2-D convolution with two 1-D passes. The library is currently at version 1.0.0, with releases made as needed rather than on a fixed schedule.
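To illustrate why separable Gaussian kernels help, here is a minimal pure-Python sketch. It is not the library's code: every function name below is hypothetical, and the window size of 11 with sigma 1.5 is assumed to match the library's defaults. It shows that blurring rows and then columns with a 1-D Gaussian gives the same result as a full 2-D Gaussian window, at roughly 2k multiplications per output pixel instead of k squared.

```python
import math


def gaussian_1d(size=11, sigma=1.5):
    """Return a normalized 1-D Gaussian window as a list of floats."""
    center = size // 2
    weights = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def filter_rows(image, kernel):
    """Apply a 1-D filter along each row, without padding ('valid' mode)."""
    k = len(kernel)
    return [
        [sum(row[j + t] * kernel[t] for t in range(k)) for j in range(len(row) - k + 1)]
        for row in image
    ]


def transpose(image):
    """Swap rows and columns of a list-of-lists image."""
    return [list(col) for col in zip(*image)]


def separable_blur(image, kernel):
    """Blur along rows, then along columns, using the same 1-D kernel."""
    rows_done = filter_rows(image, kernel)
    return transpose(filter_rows(transpose(rows_done), kernel))


def full_2d_blur(image, kernel):
    """Reference result: explicit 2-D window built as the outer product of the 1-D kernel."""
    k = len(kernel)
    h, w = len(image), len(image[0])
    return [
        [
            sum(image[i + a][j + b] * kernel[a] * kernel[b] for a in range(k) for b in range(k))
            for j in range(w - k + 1)
        ]
        for i in range(h - k + 1)
    ]


# Small deterministic test image (12 rows, 13 columns).
image = [[float((3 * i + 5 * j) % 7) for j in range(13)] for i in range(12)]
kernel = gaussian_1d()

fast = separable_blur(image, kernel)
slow = full_2d_blur(image, kernel)

# The two results should agree up to floating-point rounding.
max_diff = max(abs(x - y) for r1, r2 in zip(fast, slow) for x, y in zip(r1, r2))
print(f"max difference between separable and full 2-D blur: {max_diff:.2e}")
```

The saving grows with window size: for an 11-tap window the separable version needs about 22 multiplications per output pixel where the full 2-D window needs 121.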

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to calculate SSIM and MS-SSIM between two batches of images and how to use the SSIM and MS_SSIM classes as loss functions. Ensure your input tensors `X` and `Y` have shape `(N, C, H, W)` and that `data_range` matches your pixel value range (for example, 255 for 0-255 images or 1.0 for 0-1 images).

import torch
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

# Create two dummy image tensors (batch_size, channels, height, width)
# Images are typically non-negative, e.g., 0-255 or 0-1
X = torch.rand(4, 3, 256, 256) * 255 # Example: batch of 4 RGB images, 0-255 range
Y = torch.rand(4, 3, 256, 256) * 255 # Another batch for comparison

# Calculate SSIM and MS-SSIM values (per image in batch)
# data_range should match the maximum possible pixel value (e.g., 255 for 0-255 images)
ssim_val = ssim(X, Y, data_range=255, size_average=False) # Returns (N,) tensor
ms_ssim_val = ms_ssim(X, Y, data_range=255, size_average=False) # Returns (N,) tensor

print(f"SSIM values: {ssim_val}")
print(f"MS-SSIM values: {ms_ssim_val}")

# Using SSIM/MS_SSIM as a loss function (returns scalar mean loss)
# For loss, set size_average=True and typically use 1 - score
ssim_loss_module = SSIM(data_range=255, size_average=True, channel=3)
ms_ssim_loss_module = MS_SSIM(data_range=255, size_average=True, channel=3)

ssim_loss = 1 - ssim_loss_module(X, Y) # A scalar tensor
ms_ssim_loss = 1 - ms_ssim_loss_module(X, Y) # A scalar tensor

print(f"SSIM Loss: {ssim_loss}")
print(f"MS-SSIM Loss: {ms_ssim_loss}")
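The quickstart uses 256x256 images, which is comfortably large enough for MS-SSIM. As far as I understand the library's behavior, `ms_ssim` rejects inputs whose smaller spatial side is not greater than `(win_size - 1) * 2**4`, because the default five scales downsample the image four times. The helper below is a hypothetical pre-flight check written under that assumption. It is not part of pytorch-msssim and works on a plain shape tuple, so it runs without PyTorch.

```python
def min_side_for_ms_ssim(win_size=11, levels=5):
    """Size the smaller spatial side must exceed (assumed rule, see lead-in).

    Each of the (levels - 1) downsampling steps halves the image, and the
    Gaussian window must still fit at the coarsest scale.
    """
    return (win_size - 1) * 2 ** (levels - 1)


def check_ms_ssim_shape(shape, win_size=11, levels=5):
    """Validate an (N, C, H, W) shape tuple before calling ms_ssim.

    Hypothetical helper: raises ValueError with a readable message instead
    of letting the metric call fail later.
    """
    if len(shape) != 4:
        raise ValueError(f"expected (N, C, H, W), got {len(shape)} dimensions")
    _, _, height, width = shape
    limit = min_side_for_ms_ssim(win_size, levels)
    if min(height, width) <= limit:
        raise ValueError(
            f"smaller side is {min(height, width)}, but it must exceed {limit} "
            f"for win_size={win_size} and {levels} scales"
        )
    return True


# The quickstart batch shape passes the check.
print(check_ms_ssim_shape((4, 3, 256, 256)))

# A 128x128 batch would be too small under the assumed rule.
try:
    check_ms_ssim_shape((4, 3, 128, 128))
except ValueError as err:
    print(f"rejected: {err}")
```

In real code you would pass `tuple(X.shape)` to the check. For images that are too small, a smaller `win_size` or plain `ssim` may be the more practical choice.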
