Torchvision
Torchvision is a PyTorch domain library providing popular datasets, model architectures, and common image and video transformations for computer vision tasks. It is actively maintained and releases are synchronized with PyTorch versions, with the current version 0.26.0 compatible with torch 2.11.0. It aims to simplify the data loading, preprocessing, and model development workflow for computer vision researchers and practitioners.
Warnings
- breaking The video decoding and encoding utilities (`torchvision.io.video.*`, `read_video`, `write_video`, `VideoReader` class) were removed in Torchvision 0.26.0.
- deprecated The video decoding and encoding capabilities of TorchVision were deprecated starting from version 0.22 and were slated for removal. While initially targeted for 0.25, they were fully removed in 0.26.0.
- gotcha Since version 0.25.0, KeyPoints are no longer clamped by default after a transform. This is a behavior change from previous versions.
- gotcha The `torchvision.transforms.v2` API is the recommended and actively developed set of transforms. It offers better performance and supports transforming not just images, but also bounding boxes, masks, videos, and keypoints simultaneously.
- gotcha A version mismatch between `torch` and `torchvision` is a common cause of runtime errors (e.g., 'undefined symbol', 'CUDA toolkit version is incompatible').
- gotcha Since v0.8.0, all random transformations in `torchvision.transforms` use PyTorch's default random generator (`torch.manual_seed`) instead of Python's `random` module. Setting `random.seed()` will not affect these transforms.
Install
-

- stable release
pip install torchvision
- matched torch/torchvision pair with a specific CUDA build
pip install torch==X.Y.Z+cuXXX torchvision==A.B.C+cuXXX -f https://download.pytorch.org/whl/torch_stable.html
Imports
- transforms
from torchvision import transforms
- v2 (recommended transforms)
from torchvision.transforms import v2
- datasets
from torchvision import datasets
- models
from torchvision import models
- io (image I/O)
from torchvision import io
- tv_tensors
from torchvision import tv_tensors
Quickstart
import torch
from torchvision.transforms import v2
from torchvision import models
# 1. Create a dummy image tensor (simulating a loaded image)
H, W = 256, 256 # Example image dimensions
dummy_image = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
# 2. Define image transforms using the recommended v2 API
preprocess = v2.Compose([
v2.Resize((224, 224), antialias=True), # Resize for common model input sizes
v2.ToDtype(torch.float32, scale=True), # Convert to float and scale pixel values to [0, 1]
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet normalization
])
# Apply transforms to the dummy image
input_tensor = preprocess(dummy_image)
input_batch = input_tensor.unsqueeze(0) # Add a batch dimension (models expect batches)
# 3. Load a pre-trained image classification model (e.g., ResNet-18)
# Use the .DEFAULT enum value to automatically get the best available pre-trained weights
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
model.eval() # Set the model to evaluation mode for inference
# Get the categories the model was trained on for human-readable output
categories = weights.meta["categories"]
# 4. Perform inference
with torch.no_grad(): # Disable gradient tracking during inference to save memory and compute
output = model(input_batch)
# 5. Get the predicted class
probabilities = torch.nn.functional.softmax(output, dim=1)
predicted_probability, predicted_idx = torch.max(probabilities, 1)
predicted_label = categories[predicted_idx.item()]
print(f"Predicted class: {predicted_label} (Probability: {predicted_probability.item():.2f})")
print("Quickstart successful: Image processed and classified using torchvision.")