Spandrel: PyTorch Model Architecture Support

0.4.2 · active · verified Sat Apr 11

Spandrel is a Python library for loading and running pre-trained PyTorch models, particularly those used in image super-resolution, restoration, and inpainting. It automatically detects the model architecture and hyperparameters from common model file formats, including `.pth`, `.pt`, `.ckpt`, and `.safetensors`. The current version is 0.4.2; releases are frequent, and architectures with restrictive (non-commercial) licenses are shipped separately in the companion `spandrel_extra_arches` package.

Warnings

Install
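Spandrel is published on PyPI. Architectures under restrictive licenses live in the separate `spandrel_extra_arches` package:

```shell
pip install spandrel
# Only if you also need architectures with non-commercial licenses:
pip install spandrel spandrel_extra_arches
```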

Imports

Quickstart

This quickstart demonstrates how to load a pre-trained PyTorch model with `ModelLoader().load_from_file`. The loader automatically detects the architecture and returns a unified `ModelDescriptor` object wrapping the underlying PyTorch model. The example reads the model path from the `SPANDREL_MODEL_PATH` environment variable (with a placeholder fallback), prints metadata detected from the file, and attempts a forward pass with a dummy tensor input. Replace the placeholder path with a real `.pth` model file.

import os

import torch
from spandrel import ImageModelDescriptor, ModelLoader

# Replace this with the path to a real pre-trained model file
# (.pth, .pt, .ckpt, or .safetensors). A common practice is to
# download models from an official Super-Resolution project repository.
model_path = os.environ.get('SPANDREL_MODEL_PATH', 'path/to/your/model.pth')

if not os.path.exists(model_path):
    print(f"[NOTE]: Model file not found at '{model_path}'.")
    print("This example requires a valid model file; download one and update 'model_path'.")
else:
    try:
        # ModelLoader must be instantiated; load_from_file is an instance method.
        model = ModelLoader().load_from_file(model_path)
        print(f"Successfully loaded model (architecture: {model.architecture.name})")

        # ImageModelDescriptor exposes metadata detected from the file.
        if isinstance(model, ImageModelDescriptor):
            print(f"Model scale: {model.scale}x")
            print(f"Input/output channels: {model.input_channels}/{model.output_channels}")

            # Note: the descriptor will NOT convert an image to a tensor.
            # It expects a pre-processed float tensor in NCHW layout
            # (batch, channels, height, width).
            model.eval()
            dummy_input = torch.randn(1, model.input_channels, 128, 128)
            print(f"Attempting forward pass with dummy input shape: {dummy_input.shape}")
            try:
                with torch.no_grad():
                    output = model(dummy_input)
                print(f"Forward pass successful. Output shape: {output.shape}")
            except Exception as e:
                print(f"Error during forward pass with dummy input: {e}")
                print("The expected input shape and size depend on the specific architecture.")
        else:
            print("Not an image model descriptor; skipping the dummy forward pass.")
    except Exception as e:
        print(f"Error loading model from '{model_path}': {e}")

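Since the descriptor does not do image-to-tensor conversion, real inputs must be prepared by hand. A minimal sketch of the usual preprocessing, assuming the model expects an NCHW float tensor with values in [0, 1] (the convention for spandrel image models) and starting from an HWC uint8 array as produced by a typical image decoder; `image_to_tensor` is a hypothetical helper name:

```python
import numpy as np
import torch

def image_to_tensor(img: np.ndarray) -> torch.Tensor:
    # HWC uint8 in [0, 255] -> NCHW float32 in [0, 1]
    t = torch.from_numpy(img.astype(np.float32) / 255.0)
    return t.permute(2, 0, 1).unsqueeze(0)

img = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for a decoded image
print(image_to_tensor(img).shape)  # torch.Size([1, 3, 64, 64])
```

The resulting tensor can be passed directly to an `ImageModelDescriptor`; reverse the same steps (squeeze, permute, scale, clamp) to turn the output back into an image.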