TensorDict
TensorDict is a PyTorch-dedicated tensor container: a dictionary-like class that inherits properties from `torch.Tensor`. It simplifies working with collections of tensors, enabling tensor-like operations, efficient data management, and reusable training loops across machine learning paradigms. It is currently at version 0.12.0 and is actively developed.
Warnings
- breaking: Python 3.9 support was dropped in TensorDict v0.11.0. Python 3.10 or newer is now required.
- breaking: The deprecated methods `lock`, `unlock`, and `rename_key` (without a trailing underscore) were removed in v0.11.0. Use `lock_`, `unlock_`, and `rename_key_` for in-place modifications instead.
- breaking: The `MemoryMappedTensor._tensor` property raises a `RuntimeError` since v0.11.0. Interact with the `MemoryMappedTensor` instance directly, as it is a tensor subclass.
- gotcha: From v0.10.0, lists assigned to a TensorDict are automatically stacked by default, potentially raising a `FutureWarning`. Use an explicit context manager to pin down the desired behavior.
Install
- pip
pip install tensordict
- conda
conda install -c conda-forge tensordict
Imports
- TensorDict
from tensordict import TensorDict
- tensorclass
from tensordict import tensorclass
- MemoryMappedTensor
from tensordict import MemoryMappedTensor
Quickstart
import torch
from tensordict import TensorDict
data = TensorDict(
    obs=torch.randn(128, 84),
    action=torch.randn(128, 4),
    reward=torch.randn(128, 1),
    batch_size=[128],
)
device = "cuda" if torch.cuda.is_available() else "cpu"
data_gpu = data.to(device) # all tensors move together
sub = data_gpu[:64] # all tensors are sliced
stacked = torch.stack([data, data]) # works like a tensor
print(f"Original TensorDict:\n{data}")
print(f"Device of data_gpu: {data_gpu.device}")
print(f"Batch size of sliced TensorDict: {sub.batch_size}")
print(f"Batch size of stacked TensorDict: {stacked.batch_size}")