Torchvision

0.26.0 · active · verified Tue Mar 31

Torchvision is the PyTorch domain library for computer vision. It provides popular datasets, model architectures, and common image and video transformations. It is actively maintained, and each release is paired with a specific PyTorch release; the current version, 0.26.0, is compatible with torch 2.11.0. The library aims to simplify data loading, preprocessing, and model development for computer vision researchers and practitioners.

Warnings

Install

Imports

Quickstart

This quickstart shows how to use `torchvision` to preprocess an image and run inference with a pre-trained ResNet-18 model. It covers composing transforms with `torchvision.transforms.v2.Compose`, loading a pre-trained model from `torchvision.models`, and mapping the model output to a human-readable label. The input is a random tensor standing in for a real image, so the predicted label itself is not meaningful.

import torch
from torchvision.transforms import v2
from torchvision import models

# 1. Create a dummy image tensor (simulating a loaded image)
H, W = 256, 256 # Example image dimensions
dummy_image = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)

# 2. Define image transforms using the recommended v2 API
preprocess = v2.Compose([
    v2.Resize((224, 224), antialias=True), # Resize for common model input sizes
    v2.ToDtype(torch.float32, scale=True), # Convert to float and scale pixel values to [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet normalization
])

# Apply transforms to the dummy image
input_tensor = preprocess(dummy_image)
input_batch = input_tensor.unsqueeze(0) # Add a batch dimension (models expect batches)

# 3. Load a pre-trained image classification model (e.g., ResNet-18)
# ResNet18_Weights.DEFAULT is an alias for the best available pre-trained weights;
# the weight file is downloaded and cached on first use
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
model.eval() # Set the model to evaluation mode for inference

# Get the categories the model was trained on for human-readable output
categories = weights.meta["categories"]

# 4. Perform inference
with torch.no_grad(): # Disable gradient tracking during inference to save memory and computation
    output = model(input_batch)

# 5. Get the predicted class
probabilities = torch.nn.functional.softmax(output, dim=1)
predicted_probability, predicted_idx = torch.max(probabilities, 1)
predicted_label = categories[predicted_idx.item()]

print(f"Predicted class: {predicted_label} (Probability: {predicted_probability.item():.2f})")
print("Quickstart successful: Image processed and classified using torchvision.")
