Tensorizer

2.12.0 · active · verified Thu Apr 16

Tensorizer is a Python library developed by CoreWeave for fast serialization and deserialization of PyTorch modules, models, and tensors. It aims to reduce model load times and CPU memory usage by efficiently streaming tensor data, supporting local filesystems, HTTP/HTTPS, and S3 endpoints. The current version is 2.12.0, with ongoing development and updates primarily driven by CoreWeave's needs for serving large AI models.

Common errors

Warnings

Install
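Tensorizer is published on PyPI and can be installed with pip:

```shell
pip install tensorizer
```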

Imports

Quickstart

This quickstart demonstrates how to serialize and deserialize a PyTorch model's state dictionary using `TensorSerializer` and `TensorDeserializer`. It creates a simple neural network, saves its state, loads it into a new instance, and verifies that the outputs match. The `TensorDeserializer` is used as a context manager for proper resource handling.

import torch
from torch import nn
from tensorizer import TensorSerializer, TensorDeserializer
import os

# 1. Define a simple PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 50)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(50, 2)
        self.dummy_attribute = 'some_string_data' # Non-tensor attribute

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = SimpleModel()
dummy_input = torch.randn(1, 10)
original_output = model(dummy_input)

file_path = "./simple_model.tensors"

# 2. Serialize the model's state_dict
print(f"Serializing model to {file_path}...")
serializer = TensorSerializer(file_path)
serializer.write_state_dict(model.state_dict())
serializer.close()
print("Serialization complete.")

# 3. Create a new model instance for deserialization
loaded_model = SimpleModel()

# 4. Deserialize the model's state_dict into the new instance
print(f"Deserializing model from {file_path}...")
with TensorDeserializer(file_path, device='cpu') as loaded_state_dict:
    loaded_model.load_state_dict(loaded_state_dict)
print("Deserialization complete.")

# Verify deserialized model output
loaded_output = loaded_model(dummy_input)
assert torch.allclose(original_output, loaded_output), "Outputs do not match after serialization/deserialization!"
print("Model successfully serialized and deserialized with matching outputs.")

# Clean up the created file
os.remove(file_path)
