TensorDict Nightly
TensorDict is a PyTorch-dedicated tensor container: a dictionary-like class that exposes many `torch.Tensor`-style operations (indexing, reshaping, device casting) over whole collections of tensors. It streamlines the organization and manipulation of groups of tensors, enabling efficient batch operations, shape transformations, and seamless device management. As a nightly build, `tensordict-nightly` ships the latest features and bug fixes, with frequent updates that may introduce breaking changes.
Common errors
- IndexError: tuple index out of range
  Cause: Passing a dictionary with non-string keys (e.g., integers) to the `TensorDict` constructor or the `make_tensordict` function.
  Fix: Ensure all keys in the input dictionary are strings. For nested keys, use tuples of strings, or flatten them later with `flatten_keys` and a separator.
- RuntimeError: Cannot modify locked TensorDict.
  Cause: Attempting to modify a `TensorDict` instance that has been locked (e.g., via `td.lock_()` or `lock=True` at construction).
  Fix: Unlock the TensorDict with `td.unlock_()` before modification, or use in-place methods with a trailing underscore (e.g., `td.set_(key, value)`) if the key already exists.
- Wrong values in a TensorDict after moving it with `device='cpu'`
  Cause: When creating or moving a `TensorDict` to CPU with `non_blocking=True` (which is often implicit or default), the data transfer may not be synchronized, so values read immediately afterwards can be incorrect. This is more prevalent with non-CUDA devices.
  Fix: Explicitly pass `non_blocking=False` when moving to CPU if immediate, synchronized access is required, or make a synchronization call (`torch.cuda.synchronize()` where applicable) before accessing the data.
- `hasattr` error in `_dist_sample`, or `load_state_dict` fails when the checkpoint lacks entries for `TensorDictParams`
  Cause: These are version-specific issues, typically tied to particular PyTorch or `tensordict` releases, or to complex model architectures and distributed setups where `TensorDict` or `tensorclass` objects appear in state dictionaries or sampling code.
  Fix: Check the `pytorch/tensordict` GitHub issues for similar reports and workarounds targeting your `tensordict` and PyTorch versions, and consider updating to the latest nightly build.
Warnings
- breaking Python 3.9 support was dropped in TensorDict v0.11.0. Python 3.10 or newer is now required.
- breaking Deprecated methods `lock`, `unlock`, and `rename_key` (without a trailing underscore) were removed in v0.11.0.
- breaking The `MemoryMappedTensor._tensor` property now raises a `RuntimeError` since v0.11.0. Users should interact with the `MemoryMappedTensor` instance directly as it is a tensor subclass.
- gotcha From v0.10.0, lists assigned to a TensorDict are automatically stacked by default, potentially emitting a `FutureWarning`.
- gotcha Calling the `.to()` method on a `TensorDict` with a `dtype` argument will raise an error; `to()` is for device casting only.
Install
- pip install tensordict-nightly
- pip install tensordict-nightly --no-deps  # for uv + PyTorch nightlies
Imports
- TensorDict
from tensordict import TensorDict
- MemoryMappedTensor
from tensordict import MemoryMappedTensor
- TensorDictModule
from tensordict.nn import TensorDictModule
- tensorclass
from tensordict import tensorclass
Quickstart
import torch
from tensordict import TensorDict
# Create a TensorDict with a specified batch_size
td = TensorDict(
    {"observations": torch.randn(128, 84),
     "actions": torch.randn(128, 4)},
    batch_size=[128],
)
print("Original TensorDict:\n", td)
print("Batch size:", td.batch_size)
# Accessing elements
obs = td["observations"]
print("\nObservations shape:", obs.shape)
# Adding a new key
td["rewards"] = torch.randn(128, 1)
print("\nTensorDict after adding rewards:\n", td)
# Moving to device
if torch.cuda.is_available():
    td_gpu = td.to("cuda")
    print(f"\nTensorDict moved to {td_gpu.device}:\n", td_gpu)
# Slicing
sub_td = td[:64]
print("\nSliced TensorDict (first 64 elements):\n", sub_td)
print("Sliced batch size:", sub_td.batch_size)