Stable Baselines3 Contrib

2.8.0 · active · verified Thu Apr 16

sb3-contrib is the experimental contributions package for Stable Baselines3 (SB3): it provides reinforcement learning algorithms and features that are not (yet) part of the main SB3 library. It is currently at version 2.8.0; releases track SB3's major and minor versions, and upgrades can introduce breaking changes to the supported Python and SB3 dependency versions.
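sb3-contrib is distributed on PyPI, so a typical install looks like the following (this assumes pip is available; pinning the exact version is optional):

```shell
# Install sb3-contrib from PyPI; a compatible stable-baselines3
# release is pulled in as a dependency.
pip install sb3-contrib

# Pin the version shown in this page's header if reproducibility matters:
pip install "sb3-contrib==2.8.0"
```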

Quickstart

This quickstart demonstrates how to set up a vectorized Gymnasium environment and train an ARS (Augmented Random Search) agent from sb3-contrib. It covers environment creation, model initialization, training, and basic evaluation.

import gymnasium as gym
from sb3_contrib import ARS
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecMonitor

# 1. Create a vectorized environment
env_id = "CartPole-v1"
vec_env = make_vec_env(env_id, n_envs=4, seed=0)
vec_env = VecMonitor(vec_env) # Records episode statistics; make_vec_env already adds per-env Monitor wrappers, so this is optional here

# 2. Initialize the ARS agent
# ARS (Augmented Random Search) is a gradient-free algorithm that
# searches over policy parameters using random perturbations
model = ARS("MlpPolicy", vec_env, verbose=1)

# 3. Train the agent
print("Training the ARS model...")
model.learn(total_timesteps=10000)
print("Training finished.")

# 4. Save and load the model (optional)
model.save("ars_cartpole")
del model # remove to demonstrate loading
model = ARS.load("ars_cartpole")

# 5. Evaluate the trained agent
print("Evaluating the trained model...")
obs = vec_env.reset()  # SB3 VecEnv.reset() returns only observations (not an info dict)
for _ in range(100): # Run for 100 steps
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    # Handle episode termination for vectorized environments
    for i, done in enumerate(dones):
        if done:
            print(f"Episode finished, reward: {infos[i]['episode']['r']:.2f}")
vec_env.close()
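To build intuition for what ARS is doing during model.learn(), here is a toy sketch of the core idea (not the sb3-contrib implementation): estimate an update direction from paired random perturbations of the policy parameters, using only returns and no gradients. The 1-D "reward" function and all hyperparameter values below are made up for the demo.

```python
import random

random.seed(0)

def reward(theta):
    # Hypothetical 1-D "environment": the best return is at theta = 3.0
    return -(theta - 3.0) ** 2

theta = 0.0                              # policy "parameters"
step_size, noise, n_deltas = 0.1, 0.5, 8  # toy hyperparameters

for _ in range(200):
    update = 0.0
    for _ in range(n_deltas):
        delta = random.gauss(0.0, 1.0)
        # Evaluate the same perturbation in both directions
        r_plus = reward(theta + noise * delta)
        r_minus = reward(theta - noise * delta)
        # Weight the direction by how much better one side performed
        update += (r_plus - r_minus) * delta
    theta += step_size * update / (noise * n_deltas)

print(round(theta, 1))  # converges near the optimum, 3.0
```

Because the update uses only sampled returns, the same loop works for non-differentiable or noisy objectives, which is why ARS pairs naturally with episodic RL returns.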
