TorchRL
TorchRL is an open-source, PyTorch-native library for Reinforcement Learning (RL). It provides modular, primitive-first abstractions for building efficient and flexible RL solutions, focusing on research and rapid prototyping. The library offers components for environments, data collection, replay buffers, policy and value networks, and loss functions, all designed to integrate seamlessly with the PyTorch ecosystem. It is currently at version 0.11.1 and follows a regular release cadence, often synced with PyTorch releases.
Warnings
- breaking In TorchRL v0.11.0, the collector codebase underwent a major refactoring. Existing implementations of collectors, especially `SyncDataCollector`, `MultiSyncDataCollector`, and `aSyncDataCollector`, may require updates to align with the new modular structure.
- breaking TorchRL v0.11.0 removed several deprecated features, replacing previous warnings with errors. This includes `KLRewardTransform` (use `torchrl.envs.llm.KLRewardTransform`), `LogReward` and `Recorder` (use `LogScalar` and `LogValidationReward`), and `unbatched_*_spec` properties from `VmasWrapper`/`VmasEnv` (use `full_*_spec_unbatched`).
- gotcha Using TorchRL with PyTorch versions older than 2.0 (e.g., PyTorch 1.12 with Python 3.7) can lead to `ImportError: undefined symbol` errors when installing the stable `torchrl` package.
- gotcha In TorchRL versions prior to 0.7.2, a critical issue existed where incorrect device settings in `ParallelEnv` could prevent tensors in buffers from being properly cloned, causing rollouts to return the same tensor instances across steps and potentially leading to incorrect behavior.
- deprecated In TorchRL v0.11.0, `PPOLoss` warns when `critic_network` is used directly and recommends the `critic_coeff` argument instead to control the critic's contribution to the total loss.
Install
- Stable
pip install torchrl
- Nightly
pip install tensordict-nightly torchrl-nightly
Imports
- TensorDict
from tensordict import TensorDict
- GymEnv
from torchrl.envs import GymEnv
- MLP
from torchrl.modules import MLP
- QValueActor
from torchrl.modules import QValueActor
- PPOLoss
from torchrl.objectives import PPOLoss
- SyncDataCollector
from torchrl.collectors import SyncDataCollector
Quickstart
import torch
from torchrl.envs import GymEnv
from torchrl.modules import MLP, QValueActor
# 1. Define the environment
env = GymEnv("CartPole-v1")
# 2. Create the policy (Q-value actor with an MLP backbone)
actor = QValueActor(
    MLP(
        in_features=env.observation_spec["observation"].shape[-1],  # 4 for CartPole
        out_features=env.action_spec.shape[-1] if env.action_spec.shape else 2,  # number of discrete actions
        num_cells=[64, 64],  # two hidden layers of 64 units
    ),
    in_keys=["observation"],
    spec=env.action_spec,
)
# 3. Collect a trajectory
rollout = env.rollout(max_steps=200, policy=actor)
# Print collected info
print(f"Collected {rollout.shape[0]} steps, total reward: {rollout['next', 'reward'].sum().item():.0f}")
print(f"Rollout keys: {list(rollout.keys())}")
print(f"Example observation shape: {rollout['observation'].shape}")
env.close()