Tensorizer
Tensorizer is a Python library developed by CoreWeave for fast serialization and deserialization of PyTorch modules, models, and tensors. It aims to reduce model load times and CPU memory usage by efficiently streaming tensor data, supporting local filesystems, HTTP/HTTPS, and S3 endpoints. The current version is 2.12.0, with ongoing development and updates primarily driven by CoreWeave's needs for serving large AI models.
Common errors
- `AttributeError: 'SimpleModel' object has no attribute 'my_config_object'`
  cause: Accessing a non-tensor attribute (e.g., a custom configuration object or tokenizer) that was part of the original PyTorch model but was never serialized, because Tensorizer only handles tensors.
  fix: Save and load non-tensor attributes separately using standard Python serialization (e.g., `json`, or `pickle` only if the source is trusted). Make sure your model re-initialization logic reconstructs these attributes, or use `tensorizer.torch_compat` with awareness of its `pickle` usage.
- `libsodium.so.23: cannot open shared object file: No such file or directory`
  cause: The `libsodium` shared library, required for tensor encryption/decryption, is not installed or not discoverable on the system's library path.
  fix: Install `libsodium`. On Debian/Ubuntu, use `sudo apt-get install libsodium-dev` (or the `libsodium23` runtime package). For other operating systems, consult the official `libsodium` documentation.
- `TypeError: 'TensorDeserializer' object is not iterable`
  cause: Iterating over a `TensorDeserializer` object directly, or treating it as a dictionary outside of its context-manager usage, typically when loading a state dictionary.
  fix: Use `TensorDeserializer` as a context manager when loading state dictionaries or accessing its dictionary-like contents: `with TensorDeserializer(uri) as state_dict: model.load_state_dict(state_dict)`.
Warnings
- gotcha Tensorizer explicitly serializes only tensors. Unlike `torch.save`, it does NOT use Python's `pickle` module for arbitrary Python objects. If your `torch.nn.Module` contains non-tensor attributes critical to its functionality (e.g., custom configuration objects, tokenizers), these will NOT be saved by `TensorSerializer.write_module` or `write_state_dict` and will be missing upon deserialization.
- security When using the `tensorizer.torch_compat` module as a drop-in replacement for `torch.save` and `torch.load`, any non-tensor data is still handled internally by Python's `pickle` module. Loading untrusted pickled files can lead to arbitrary code execution. This risk applies even though `tensorizer` itself only handles tensor data.
- gotcha Tensor encryption/decryption functionality requires the external `libsodium` library to be installed on your system. Without it, attempts to use encryption features will fail.
- gotcha `TensorDeserializer` is designed as a context manager, especially important for lazy loading. While it can be instantiated without `with`, using it as a context manager ensures proper resource cleanup and file closing, preventing potential file handle leaks or unexpected behavior, especially when streaming from remote sources.
- gotcha Loading and serializing very large models (e.g., EleutherAI/gpt-j-6B) with `tensorizer` requires substantial CPU RAM (up to ~20GB) and GPU VRAM (~16GB), even with `tensorizer`'s efficiency. Ensure your environment has sufficient resources.
Install
- `pip install tensorizer`
- `pip install git+https://github.com/coreweave/tensorizer`
Imports
- TensorSerializer
from tensorizer import TensorSerializer
- TensorDeserializer
from tensorizer import TensorDeserializer
- tensorizer_saving
from tensorizer.torch_compat import tensorizer_saving
- tensorizer_loading
from tensorizer.torch_compat import tensorizer_loading
Quickstart
import torch
from torch import nn
from tensorizer import TensorSerializer, TensorDeserializer
import os
# 1. Define a simple PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 50)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(50, 2)
        self.dummy_attribute = 'some_string_data'  # Non-tensor attribute
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))
model = SimpleModel()
dummy_input = torch.randn(1, 10)
original_output = model(dummy_input)
file_path = "./simple_model.tensors"
# 2. Serialize the model's state_dict
print(f"Serializing model to {file_path}...")
serializer = TensorSerializer(file_path)
serializer.write_state_dict(model.state_dict())
serializer.close()
print("Serialization complete.")
# 3. Create a new model instance for deserialization
loaded_model = SimpleModel()
# 4. Deserialize the model's state_dict into the new instance
print(f"Deserializing model from {file_path}...")
with TensorDeserializer(file_path, device='cpu') as loaded_state_dict:
    loaded_model.load_state_dict(loaded_state_dict)
print("Deserialization complete.")
# Verify deserialized model output
loaded_output = loaded_model(dummy_input)
assert torch.allclose(original_output, loaded_output), "Outputs do not match after serialization/deserialization!"
print("Model successfully serialized and deserialized with matching outputs.")
# Clean up the created file
os.remove(file_path)