Stable Baselines3

2.8.0 · active · verified Sun Apr 12

Stable Baselines3 (SB3) is a Python library of reliable reinforcement learning (RL) algorithm implementations built on PyTorch. Its API follows a scikit-learn-like pattern for training, evaluating, and deploying RL agents: construct a model, call learn() to train it, predict() to act, and save()/load() to persist it. SB3 is actively maintained with frequent releases and supports model-free RL algorithms such as A2C, PPO, SAC, DQN, and TD3.

Warnings

Install

Imports

Quickstart

This quickstart creates a Gymnasium environment, instantiates an A2C agent, trains it for 10,000 timesteps, saves and reloads the trained model, and finally runs the reloaded agent in the environment for 1,000 steps.

import gymnasium as gym
from stable_baselines3 import A2C

# Create environment
env = gym.make("CartPole-v1")

# Instantiate the agent
model = A2C("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=10000)

# Save the model
model.save("a2c_cartpole")

# Delete model and reload it to demonstrate saving and loading
del model
model = A2C.load("a2c_cartpole")

# Run the trained agent for 1000 steps, resetting whenever an episode ends
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
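
The loop above steps the agent through the environment but does not record how well it performs. A minimal sketch of one way to measure that is to sum the rewards of each episode and report the mean and standard deviation. The helper name episode_returns and the two stand-in classes CountdownEnv and ConstantPolicy are hypothetical; they only mimic the predict(), reset(), and step() calls used in the quickstart so the sketch runs without SB3 or Gymnasium installed. In practice you would pass the real model and env instead. SB3 also ships its own evaluate_policy helper in stable_baselines3.common.evaluation, which is likely the better choice for real use.

```python
from statistics import mean, pstdev


def episode_returns(model, env, n_episodes=10, deterministic=True):
    """Run n_episodes complete episodes and return each undiscounted return."""
    returns = []
    for _ in range(n_episodes):
        obs, info = env.reset()
        total, done = 0.0, False
        while not done:
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, info = env.step(action)
            total += float(reward)
            # An episode ends on either termination or truncation
            done = terminated or truncated
        returns.append(total)
    return returns


# Hypothetical stand-ins (not part of SB3 or Gymnasium) so the sketch is runnable.
class CountdownEnv:
    """Toy environment: reward 1.0 per step, truncated after `horizon` steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0, {}

    def step(self, action):
        self.t += 1
        truncated = self.t >= self.horizon
        return self.t, 1.0, False, truncated, {}


class ConstantPolicy:
    """Toy policy: always returns action 0, mirroring the predict() return shape."""

    def predict(self, obs, deterministic=True):
        return 0, None


returns = episode_returns(ConstantPolicy(), CountdownEnv(horizon=5), n_episodes=3)
print(mean(returns), pstdev(returns))
```

Measuring whole episodes rather than a fixed number of steps avoids counting a partial final episode, which would bias the average return downward.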
