Segmentation Models PyTorch

0.5.0 · active · verified Mon Apr 13

Segmentation Models PyTorch (SMP) is a Python library offering a high-level API for various neural network architectures, pre-trained backbones, losses, and metrics for image semantic segmentation. It supports 12 encoder-decoder architectures and over 800 pre-trained convolutional and transformer-based encoders, leveraging `timm` for a vast selection. The library focuses on simplicity, fast convergence, and compatibility with PyTorch's `torch.jit.script`, `torch.compile`, and `torch.export` features. It is currently at version 0.5.0 and maintains an active release cadence with frequent updates and new model integrations.

Warnings
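
Install

SMP is published on PyPI; a typical installation into an existing Python environment is:

```shell
pip install segmentation-models-pytorch
```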

Quickstart

This quickstart demonstrates how to instantiate a U-Net model with a pre-trained ResNet34 encoder, configure input and output channels, and set an activation function. It also shows how to obtain and apply the correct preprocessing function required for models with pre-trained backbones to ensure optimal performance.

import numpy as np
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# 1. Create segmentation model
model = smp.Unet(
    encoder_name="resnet34",         # choose encoder backbone
    encoder_weights="imagenet",      # use ImageNet pre-trained weights
    in_channels=3,                   # model input channels (3 for RGB)
    classes=1,                       # model output channels (number of classes)
    activation="sigmoid",            # activation applied to the output (binary segmentation)
)

# 2. Configure data preprocessing (important for pre-trained encoders)
preprocess_input = get_preprocessing_fn("resnet34", pretrained="imagenet")

# Dummy input image; the preprocessing function expects a channels-last
# (height, width, channels) numpy array, not a batched NCHW torch tensor
image = np.random.randint(0, 256, (256, 256, 3)).astype(np.float32)

# Apply preprocessing (scaling and normalization), then convert to an
# NCHW tensor for the model
preprocessed = preprocess_input(image)
input_tensor = torch.from_numpy(preprocessed).permute(2, 0, 1).unsqueeze(0).float()

# 3. Forward pass
model.eval()
with torch.no_grad():
    predicted_mask = model(input_tensor)

print(f"Model output shape: {predicted_mask.shape}")  # torch.Size([1, 1, 256, 256])
