Segmentation Models PyTorch
Segmentation Models PyTorch (SMP) is a Python library offering a high-level API for various neural network architectures, pre-trained backbones, losses, and metrics for image semantic segmentation. It supports 12 encoder-decoder architectures and over 800 pre-trained convolutional and transformer-based encoders, leveraging `timm` for a vast selection. The library focuses on simplicity, fast convergence, and compatibility with PyTorch's `torch.jit.script`, `torch.compile`, and `torch.export` features. It is currently at version 0.5.0 and maintains an active release cadence with frequent updates and new model integrations.
Warnings
- breaking The `UperNet` model architecture underwent significant changes in v0.5.0, making model weights trained with v0.4.0 incompatible with v0.5.0. Existing UperNet models will need to be re-trained or adapted.
- deprecated Encoders from the `timm` library previously accessed with a `timm-` prefix (e.g., `timm-resnet34`) are deprecated in v0.5.0. The recommended way to use `timm` encoders is now with the `tu-` prefix (e.g., `tu-resnet34`).
- deprecated The `smp.utils.losses` module was deprecated in v0.2.0. All loss functions have been moved to the `smp.losses` module.
- breaking The minimum Python version requirement was increased from 3.6 to 3.7.
- gotcha To ensure compatibility with `albumentations` versions >= 1.4.0, some internal function names that interact with `albumentations` may have changed, requiring updates if you directly extended or modified SMP's data processing pipelines.
- gotcha For optimal performance, especially when using pre-trained encoders, it is crucial to apply the correct preprocessing steps (e.g., normalization, resizing) to your input data, matching how the encoder's weights were pre-trained. Use `smp.encoders.get_preprocessing_fn` for this.
- gotcha Most models require input height and width to be divisible by the encoder's output stride (typically 32). Mismatched sizes either raise an error or are handled with model-specific interpolation/padding, which can degrade performance.
Install
-
pip install segmentation-models-pytorch
Imports
- smp
import segmentation_models_pytorch as smp
- Unet
model = smp.Unet(...)
- get_preprocessing_fn
from segmentation_models_pytorch.encoders import get_preprocessing_fn
Quickstart
import numpy as np
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
# 1. Create segmentation model
model = smp.Unet(
    encoder_name="resnet34",     # choose encoder backbone
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights
    in_channels=3,               # model input channels (3 for RGB)
    classes=1,                   # model output channels (number of classes)
    activation="sigmoid",        # activation function for binary segmentation
)
# 2. Configure data preprocessing (important for pre-trained encoders)
# The returned function expects a channels-last (H, W, C) image array and
# applies the normalization the encoder was pre-trained with.
preprocess_input = get_preprocessing_fn("resnet34", pretrained="imagenet")
# Dummy input image (height=256, width=256, channels=3), values in [0, 255]
image = np.random.randint(0, 256, size=(256, 256, 3)).astype("float32")
image = preprocess_input(image)
# Convert to a (batch, channels, height, width) tensor for the model
input_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()
# Forward pass
model.eval()
with torch.no_grad():
    predicted_mask = model(input_tensor)
print(f"Model output shape: {predicted_mask.shape}")  # torch.Size([1, 1, 256, 256])