TensorDict

0.12.0 · active · verified Sat Apr 11

TensorDict is a PyTorch-dedicated tensor container. It is a dictionary-like class whose entries share one or more leading batch dimensions, so the collection as a whole behaves like a tensor. It supports indexing, slicing, stacking, reshaping and moving to a device. This simplifies working with collections of tensors, keeps data management efficient, and makes training loops reusable across machine learning paradigms. It is currently at version 0.12.0 and is actively developed.

Install

pip install tensordict

Imports

from tensordict import TensorDict

Quickstart

This quickstart creates a TensorDict, moves it to a device, slices it like a tensor, and stacks several TensorDicts together.

import torch
from tensordict import TensorDict

data = TensorDict(
    obs=torch.randn(128, 84),
    action=torch.randn(128, 4),
    reward=torch.randn(128, 1),
    batch_size=[128],
)

device = "cuda" if torch.cuda.is_available() else "cpu"
data_gpu = data.to(device) # all tensors move together
sub = data_gpu[:64] # all tensors are sliced
stacked = torch.stack([data, data]) # works like a tensor

print(f"Original TensorDict:\n{data}")
print(f"Device of data_gpu: {data_gpu.device}")
print(f"Batch size of sliced TensorDict: {sub.batch_size}")
print(f"Batch size of stacked TensorDict: {stacked.batch_size}")
