High-Performance Safetensors Model Loader

0.2.2 · active · verified Fri Apr 17

fastsafetensors is a Python library for high-performance loading of safetensors models, optimized for GPU environments (CUDA, ROCm). It aims to load large models faster than the standard `safetensors` library. The current version is `0.2.2`, and the project is actively maintained, with frequent bug-fix and performance releases.

Quickstart

This quickstart creates a dummy safetensors file with the standard `safetensors` library, then loads it with `fastsafetensors.FastSafetensorsFile`. It shows how to inspect the file's tensor metadata without loading any data, and how to load individual tensors lazily by indexing the file object like a dictionary. The example uses `torch` to create and load tensors, so both `torch` and `safetensors` must be installed to run it.

import os

import torch
from safetensors.torch import save_file
from fastsafetensors import FastSafetensorsFile

# 1. Create a dummy safetensors file for demonstration
dummy_data = {
    "layer1.weight": torch.randn(128, 64),
    "layer1.bias": torch.zeros(128),
    "layer2.weight": torch.ones(64, 32)
}
dummy_file_path = "dummy_model.safetensors"
save_file(dummy_data, dummy_file_path)

print(f"Created dummy safetensors file: {dummy_file_path}\n")

# 2. Load the safetensors file using FastSafetensorsFile
try:
    fsf = FastSafetensorsFile(dummy_file_path)

    # 3. Inspect tensor metadata (does not load data into memory)
    print("Tensors available in the file (metadata only):")
    for name, metadata in fsf.get_tensors().items():
        print(f"  - {name}: {metadata}")

    # 4. Access a specific tensor (this triggers loading for that tensor)
    tensor_name = "layer1.weight"
    loaded_tensor = fsf[tensor_name]
    print(f"\nSuccessfully loaded '{tensor_name}':")
    print(f"  Type: {type(loaded_tensor)}")
    print(f"  Shape: {loaded_tensor.shape}")
    print(f"  First 5 elements:\n{loaded_tensor.flatten()[:5]}\n")

    # Access another tensor
    print(f"Accessing 'layer2.weight' (shape: {fsf['layer2.weight'].shape})\n")

except Exception as e:
    print(f"An error occurred: {e}")
finally:
    # 5. Clean up the dummy file
    if os.path.exists(dummy_file_path):
        os.remove(dummy_file_path)
        print(f"Cleaned up dummy file: {dummy_file_path}")
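
Metadata inspection (step 3) can be cheap because of how the safetensors format is laid out: a file starts with an 8-byte little-endian header length, followed by a JSON header that maps each tensor name to its `dtype`, `shape`, and `data_offsets`, and only then the raw tensor bytes. The sketch below illustrates that layout using only the Python standard library. It does not use or describe the internals of `fastsafetensors`; the helper names `write_minimal_safetensors`, `read_safetensors_header`, and `read_tensor_bytes` are hypothetical and exist only for this illustration.

```python
import json
import os
import struct
import tempfile


def write_minimal_safetensors(path, tensors):
    """Hypothetical helper: write tensors in the safetensors layout.

    `tensors` maps name -> (dtype string, shape list, raw little-endian bytes).
    """
    header = {}
    offset = 0
    for name, (dtype, shape, raw) in tensors.items():
        # Offsets are relative to the start of the data section, not the file.
        header[name] = {
            "dtype": dtype,
            "shape": shape,
            "data_offsets": [offset, offset + len(raw)],
        }
        offset += len(raw)
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte header length
        f.write(header_bytes)                          # JSON header
        for _dtype, _shape, raw in tensors.values():   # raw tensor data
            f.write(raw)


def read_safetensors_header(path):
    """Hypothetical helper: read only the JSON header, never the tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form string metadata
    return header


def read_tensor_bytes(path, header, name):
    """Hypothetical helper: read the bytes of a single tensor on demand."""
    begin, end = header[name]["data_offsets"]
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        f.seek(8 + header_len + begin)  # skip length prefix and header
        return f.read(end - begin)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "dummy_model.safetensors")
        write_minimal_safetensors(path, {
            "layer1.bias": ("F32", [4], struct.pack("<4f", 0.0, 0.0, 0.0, 0.0)),
            "layer2.weight": ("F32", [2, 2], struct.pack("<4f", 1.0, 1.0, 1.0, 1.0)),
        })

        # Metadata only: names, dtypes, and shapes come from the header alone.
        header = read_safetensors_header(path)
        for name, info in header.items():
            print(name, info["dtype"], info["shape"])

        # Lazy access: seek to one tensor's byte range and decode it.
        raw = read_tensor_bytes(path, header, "layer2.weight")
        print(struct.unpack("<4f", raw))
```

Because every tensor's byte range is recorded in the header, a loader can list names and shapes after reading a few kilobytes and fetch tensors individually later, which is the access pattern the quickstart relies on.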
