Stable Baselines3 Contrib
sb3-contrib is the experimental contribution package for Stable Baselines3, providing additional reinforcement learning algorithms and features not yet integrated into the main SB3 library. It is currently at version 2.8.0 and typically releases new versions in sync with Stable Baselines3's major and minor updates, often introducing breaking changes related to Python or SB3 dependency versions.
Common errors
- ImportError: cannot import name 'QRDQN' from 'sb3_contrib'
  cause: `QRDQN` is provided by `sb3-contrib` itself, so this usually indicates a badly outdated or broken install, or a local file/directory shadowing the package.
  fix: Reinstall or upgrade with `pip install --upgrade sb3-contrib` and import it as `from sb3_contrib import QRDQN`.
- AttributeError: 'VecMonitor' object has no attribute 'action_masks'
  cause: `MaskablePPO` requires the environment (or its wrapper chain) to implement an `action_masks()` method that returns the valid-action mask.
  fix: Implement `action_masks()` in your custom `gymnasium.Env` or in a `gymnasium.Wrapper` around it. If using a `VecEnv`, ensure the underlying environments provide masks and that they are forwarded correctly; for `model.predict` you will likely need to pass `action_masks` explicitly.
- stable_baselines3.common.utils.SB3DeprecationWarning: You are using an outdated version of Stable-Baselines3
  cause: `sb3-contrib` requires a specific, often very recent, version of `stable-baselines3` for compatibility.
  fix: Upgrade both packages together to compatible versions: `pip install --upgrade stable-baselines3 sb3-contrib`.
- TypeError: zip() argument 'strict' must be bool, not None
  cause: Since `sb3-contrib` v2.8.0, internal `zip` calls use `strict=True`, which raises when the zipped sequences have different lengths. This usually points to inconsistencies in the environment's observation/action spaces or in custom data passed to the model.
  fix: Check that your environment's observation and action space definitions, and any custom data structures passed into the model, have consistent and expected lengths.
Warnings
- breaking Python 3.9 support was removed in v2.8.0. Earlier versions (v2.5.0, v2.1.0) dropped support for Python 3.8 and 3.7 respectively. Ensure your Python version meets the minimum requirement (>=3.10 for v2.8.0).
- breaking sb3-contrib is tightly coupled with `stable-baselines3`. New versions of `sb3-contrib` frequently require specific, often newer, versions of `stable-baselines3` (e.g., v2.8.0 requires SB3 >= 2.8.0).
- gotcha The `QRDQN` algorithm ships with `sb3-contrib`, not with core `stable_baselines3`. Attempting to import it via `from stable_baselines3 import QRDQN` will result in an `ImportError`.
- gotcha `MaskablePPO` requires the environment to implement an `action_masks()` method returning a boolean numpy array of valid actions. This is not a standard `gymnasium.Env` or `VecEnv` feature.
- breaking The default `learning_starts` parameter for `QRDQN` was significantly changed in `sb3-contrib` v2.3.0 (from 50_000 to 100) to align with other off-policy algorithms. This can drastically alter training behavior if not explicitly set.
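To make the `action_masks()` contract concrete, here is a minimal, dependency-free sketch. The `GridEnv` class and `masked_choice` helper are hypothetical illustrations; a real environment would subclass `gymnasium.Env` and return a numpy bool array:

```python
# Hypothetical toy env: 4 discrete actions [left, right, stay, reset];
# moving past either edge of a 1-D grid is invalid.
class GridEnv:
    def __init__(self, size=5):
        self.size = size
        self.pos = size // 2

    def action_masks(self):
        # One boolean per action; False marks an invalid action.
        return [self.pos > 0, self.pos < self.size - 1, True, True]

def masked_choice(logits, mask):
    # Greedy selection restricted to actions the mask allows.
    valid = [i for i, ok in enumerate(mask) if ok]
    return max(valid, key=lambda i: logits[i])

env = GridEnv()
env.pos = 0                # at the left edge: "left" becomes invalid
mask = env.action_masks()
print(mask)                # [False, True, True, True]
# "left" has the highest score (9.0) but is masked out, so "right" wins
print(masked_choice([9.0, 1.0, 0.5, 0.2], mask))  # 1
```

With sb3-contrib installed, defining this method on your `gymnasium.Env` is what `MaskablePPO` looks for during training; at inference time the mask is typically passed explicitly, e.g. `model.predict(obs, action_masks=...)`.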
Install
-
pip install sb3-contrib stable-baselines3 gymnasium
Imports
- MaskablePPO
from sb3_contrib import MaskablePPO
- RecurrentPPO
from sb3_contrib import RecurrentPPO
- ARS
from sb3_contrib import ARS
- TRPO
from sb3_contrib import TRPO
- CrossQ
from sb3_contrib import CrossQ
- QRDQN
from sb3_contrib import QRDQN
Quickstart
from sb3_contrib import ARS
from stable_baselines3.common.env_util import make_vec_env
# 1. Create a vectorized environment
# make_vec_env already wraps each env in Monitor, so episode stats land in `infos`
env_id = "CartPole-v1"
vec_env = make_vec_env(env_id, n_envs=4, seed=0)
# 2. Initialize the ARS agent
# ARS (Augmented Random Search) is a gradient-free random-search algorithm
model = ARS("MlpPolicy", vec_env, verbose=1)
# 3. Train the agent
print("Training the ARS model...")
model.learn(total_timesteps=10_000)
print("Training finished.")
# 4. Save and load the model (optional)
model.save("ars_cartpole")
del model  # remove to demonstrate loading
model = ARS.load("ars_cartpole")
# 5. Evaluate the trained agent
print("Evaluating the trained model...")
obs = vec_env.reset()  # SB3 VecEnv.reset() returns only the observations
for _ in range(100):  # run for 100 steps
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    # VecEnvs auto-reset; finished episodes report their stats through Monitor
    for i, done in enumerate(dones):
        if done:
            print(f"Episode finished, reward: {infos[i]['episode']['r']:.2f}")
vec_env.close()