Spandrel: PyTorch Model Architecture Support
Spandrel is a Python library for loading and running pre-trained PyTorch models, particularly those used in AI super-resolution, restoration, and inpainting. It automatically detects the model architecture and its hyperparameters from the model file, and supports the `.pth`, `.pt`, `.ckpt`, and `.safetensors` formats. The current version is 0.4.2, and releases are typically published alongside the companion `spandrel_extra_arches` package, which covers architectures with restrictive licenses.
Warnings
- gotcha `spandrel` loads `.pth` files with Python's `pickle` module, which is vulnerable to arbitrary code execution. `spandrel` mitigates this by restricting which data types it will deserialize, but that does not fully eliminate the risk. Only load `.pth` files from trusted sources.
- gotcha The `spandrel` package contains architectures with permissive licenses (MIT, Apache 2.0, public domain). For architectures with more restrictive licenses (e.g., non-commercial), you need to install `spandrel_extra_arches`. Be aware of the licensing implications of `spandrel_extra_arches` for commercial or closed-source projects.
- breaking Starting from v0.4.0, all architectures in `spandrel` (and `spandrel_extra_arches`) require keyword arguments for their `load` methods. This change improves clarity and prevents errors when calling architecture-specific loading functions.
- gotcha The `ModelDescriptor` and its variants (e.g., `ImageModelDescriptor`) provided by `spandrel` do not perform image-to-tensor conversion. You must pre-process your image data into a `torch.Tensor` with the correct shape and data type before passing it to the model for inference.
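Because the descriptors expect a ready-made tensor, callers must do the image-to-tensor conversion themselves. A minimal preprocessing sketch (numpy shown; the helper name and the [0, 1] value range are illustrative assumptions, not part of spandrel's API):

```python
import numpy as np

def image_to_tensor(image: np.ndarray) -> np.ndarray:
    # Hypothetical helper: spandrel does not provide this conversion.
    # HWC uint8 image -> NCHW float32 array scaled to [0, 1].
    x = image.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    x = np.transpose(x, (2, 0, 1))        # HWC -> CHW
    return x[np.newaxis, ...]             # add batch dimension -> NCHW
```

Wrap the result with `torch.from_numpy(...)` (and move it to the model's device) before inference.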
Install
- pip install spandrel
- pip install spandrel spandrel_extra_arches (then call `spandrel_extra_arches.install()` once after import to register the extra architectures)
Imports
- ModelLoader
from spandrel import ModelLoader
Quickstart
import torch
from spandrel import ImageModelDescriptor, ModelLoader

# Path to a pre-trained model file (.pth, .pt, .ckpt, or .safetensors).
model_path = "path/to/your/model.pth"

# load_from_file is an instance method, so construct a ModelLoader first.
model = ModelLoader().load_from_file(model_path)

# Make sure this is an image-to-image model before treating it as one.
assert isinstance(model, ImageModelDescriptor)
print(f"Loaded {model.architecture.name} (scale: {model.scale}x)")

# Put the model in inference mode; move it to the GPU if one is available.
model.eval()
if torch.cuda.is_available():
    model.cuda()

# ImageModelDescriptor will NOT convert an image to a tensor for you.
# Inputs must be float tensors in NCHW layout (batch, channels, height, width).
image = torch.rand(1, model.input_channels, 128, 128)
if torch.cuda.is_available():
    image = image.cuda()

with torch.no_grad():
    output = model(image)
print(f"Output shape: {output.shape}")
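The forward pass returns a tensor in the same NCHW layout as the input. A minimal sketch of converting the output back to an image (numpy shown; the helper name and the assumption that outputs lie in [0, 1] are illustrative, not part of spandrel's API):

```python
import numpy as np

def tensor_to_image(batch: np.ndarray) -> np.ndarray:
    # Hypothetical helper: NCHW float array -> HWC uint8 image.
    x = np.clip(batch[0], 0.0, 1.0)             # drop batch dim, clamp to [0, 1]
    x = np.transpose(x, (1, 2, 0))              # CHW -> HWC
    return (x * 255.0 + 0.5).astype(np.uint8)   # round to 8-bit pixels
```

For a torch output, call `output.cpu().numpy()` first.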