TorchRL

0.11.1 · active · verified Mon Apr 13

TorchRL is an open-source, PyTorch-native library for Reinforcement Learning (RL). It provides modular, primitive-first abstractions for building efficient and flexible RL solutions, focusing on research and rapid prototyping. The library offers components for environments, data collection, replay buffers, policy and value networks, and loss functions, all designed to integrate seamlessly with the PyTorch ecosystem. It is currently at version 0.11.1 and follows a regular release cadence, often synced with PyTorch releases.

Install

pip install torchrl

Imports

import torch
from tensordict import TensorDict
from torchrl.envs import GymEnv
from torchrl.modules import MLP, QValueActor

Quickstart

This quickstart creates a simple Gym environment, defines a Q-value policy with an MLP backbone, and collects a trajectory of up to 200 steps (the rollout ends early if the episode terminates). The collected data is returned as a TensorDict.

import torch
from torchrl.envs import GymEnv
from torchrl.modules import MLP, QValueActor
from tensordict import TensorDict

# 1. Define the environment
env = GymEnv("CartPole-v1")

# 2. Create the policy (Q-value actor with an MLP backbone)
actor = QValueActor(
    MLP(
        in_features=env.observation_spec["observation"].shape[-1],
        out_features=env.action_spec.shape[-1],  # 2 discrete actions for CartPole
        num_cells=[64, 64],
    ),
    in_keys=["observation"],
    spec=env.action_spec,
)

# 3. Collect a trajectory
rollout = env.rollout(max_steps=200, policy=actor)

# Print collected info
print(f"Collected {rollout.shape[0]} steps, total reward: {rollout['next', 'reward'].sum().item():.0f}")
print(f"Rollout keys: {rollout.keys()}")
print(f"Example observation shape: {rollout['observation'].shape}")

env.close()
