spandrel-extra-arches

0.2.0 · active · verified Thu Apr 16

spandrel-extra-arches is a Python library that implements PyTorch model architectures with restrictive licenses, extending the core `spandrel` package. Once its architectures are registered with `spandrel`, `spandrel` can automatically detect and load models whose architectures carry non-commercial or other restrictive license terms, while keeping them separate from `spandrel`'s permissively licensed architectures. The current version is 0.2.0, released on September 15, 2024; it is actively maintained and developed alongside `spandrel`.
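
The split described above is a plugin-registry pattern: the core package ships a registry of permissively licensed architectures, and the add-on contributes its restrictively licensed ones only when the user opts in. The toy sketch below models that idea with the standard library only. `ArchInfo`, `Registry`, `MAIN_REGISTRY`, `EXTRA_ARCHS`, `install`, and every architecture id and license string are hypothetical illustrations, not the real `spandrel` or `spandrel-extra-arches` API.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ArchInfo:
    """Hypothetical description of one architecture: an id plus its license."""
    arch_id: str
    license: str


@dataclass
class Registry:
    """Hypothetical stand-in for a core package's architecture registry."""
    archs: dict[str, ArchInfo] = field(default_factory=dict)

    def add(self, *infos: ArchInfo) -> None:
        for info in infos:
            # Refuse to silently overwrite an architecture that is already registered.
            if info.arch_id in self.archs:
                raise ValueError(f"duplicate architecture id: {info.arch_id}")
            self.archs[info.arch_id] = info

    def ids(self) -> list[str]:
        return sorted(self.archs)


# The core package ships only permissively licensed architectures (hypothetical ids).
MAIN_REGISTRY = Registry()
MAIN_REGISTRY.add(ArchInfo("CoreNetA", "MIT"), ArchInfo("CoreNetB", "Apache-2.0"))

# The add-on package keeps its restrictively licensed architectures separate...
EXTRA_ARCHS = (
    ArchInfo("ExtraNetX", "CC-BY-NC-4.0"),
    ArchInfo("ExtraNetY", "non-commercial research license"),
)


def install(registry: Registry = MAIN_REGISTRY) -> None:
    """...and only adds them to the core registry when the user opts in."""
    registry.add(*EXTRA_ARCHS)


print(MAIN_REGISTRY.ids())  # ['CoreNetA', 'CoreNetB']
install()
print(MAIN_REGISTRY.ids())  # ['CoreNetA', 'CoreNetB', 'ExtraNetX', 'ExtraNetY']
```

Keeping the restricted architectures in a separate package means a project that never calls the opt-in step never pulls them into its registry.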

Quickstart

This quickstart shows how to load a model file with `spandrel` after registering the architectures provided by `spandrel-extra-arches`. `spandrel` does not pick these architectures up on its own: call `spandrel_extra_arches.install()` once, before any `ModelLoader` is used. After that, `spandrel`'s `ModelLoader` detects the architecture from the model file and instantiates it, whether it comes from the core package or from the extra architectures.

import os

import spandrel
import spandrel_extra_arches

# Install both packages first: pip install spandrel spandrel-extra-arches
# spandrel does not pick up the extra architectures on its own, so register them
# with spandrel's main registry before any ModelLoader is used.
spandrel_extra_arches.install()

# Example: assume you have a .pth model file for an architecture in spandrel-extra-arches.
# 'my_extra_arch_model.pth' is a placeholder path; ModelLoader detects the actual
# architecture from the file's contents.
model_path = 'my_extra_arch_model.pth'

try:
    if not os.path.exists(model_path):
        # No model file available, so skip the loading step.
        print(f"[INFO] No model file found at {model_path}. Skipping model loading example.")
        print("[INFO] Please provide a valid .pth model file for an architecture supported by spandrel-extra-arches to run this part.")
    else:
        # Load the model; the result is a model descriptor wrapping the PyTorch module
        descriptor = spandrel.ModelLoader().load_from_file(model_path)
        print(f"Model loaded successfully: {descriptor.architecture.name} (id: {descriptor.architecture.id})")

        # The underlying PyTorch module
        pytorch_module = descriptor.model
        print(f"PyTorch module type: {type(pytorch_module)}")

except Exception as e:
    print(f"An error occurred during model loading: {e}")
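
To make the detection step in the quickstart concrete, here is a toy model of how a registry-based loader could choose an architecture from a state dict. It also shows why the install step has to come before loading. It assumes, as a simplification, that detection works by checking for characteristic parameter names. `has_keys`, `registry`, `install_extras`, `detect_arch`, and all key names and architecture ids are hypothetical; `spandrel`'s real detection logic is more involved. State dicts are modeled as plain `{name: shape}` dictionaries so no PyTorch install is needed.

```python
from typing import Callable, Mapping

# A state dict is modeled here as parameter name -> tensor shape (no torch needed).
StateDict = Mapping[str, tuple[int, ...]]
Detector = Callable[[StateDict], bool]


def has_keys(*keys: str) -> Detector:
    """Build a detector that matches when all characteristic keys are present."""
    def detect(state_dict: StateDict) -> bool:
        return all(k in state_dict for k in keys)
    return detect


# Hypothetical registry: ordered (arch_id, detector) pairs, checked first to last.
registry: list[tuple[str, Detector]] = [
    ("CoreNetA", has_keys("core_a.head.weight", "core_a.body.weight")),
]


def install_extras() -> None:
    """Hypothetical opt-in step that appends restrictively licensed architectures."""
    registry.append(("ExtraNetX", has_keys("extra_x.encoder.weight", "extra_x.codebook")))


def detect_arch(state_dict: StateDict) -> str:
    """Return the id of the first registered architecture whose detector matches."""
    for arch_id, detect in registry:
        if detect(state_dict):
            return arch_id
    raise ValueError("unsupported model: no registered architecture matched")


extra_weights = {"extra_x.encoder.weight": (64, 3, 3, 3), "extra_x.codebook": (1024, 256)}

try:
    detect_arch(extra_weights)
except ValueError as e:
    print(e)  # the extra architecture is not registered yet

install_extras()
print(detect_arch(extra_weights))  # ExtraNetX
```

Because the loader only knows about what is in its registry, a model built on one of the extra architectures is reported as unsupported until the opt-in step has run.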
