Stable Baselines3
Stable Baselines3 (SB3) is a comprehensive Python library offering reliable implementations of reinforcement learning (RL) algorithms in PyTorch. It provides a clean and simple API, adhering to a scikit-learn-like syntax for training, evaluating, and deploying RL agents. SB3 is actively maintained with frequent releases, supporting state-of-the-art model-free RL algorithms like A2C, PPO, SAC, DQN, and TD3.
Warnings
- breaking Python 3.9 support was dropped in v2.8.0; users on Python 3.9 must upgrade to Python >= 3.10. Python 3.8 support was removed earlier, in v2.4.0/v2.5.0.
- breaking The minimum required PyTorch version increased to 2.3.0 in Stable Baselines3 v2.5.0. Ensure your PyTorch installation meets this requirement.
- breaking Stable Baselines3 switched to Gymnasium as its primary environment backend starting from v2.0.0. While compatibility layers exist via `shimmy` for older `gym` environments, direct migration to Gymnasium is highly recommended.
- breaking Stable Baselines3 v2.3.0 introduced a breaking change where `torch.load()` was called with `weights_only=True`, causing issues when loading policies trained with PyTorch 1.13. This was reverted in v2.3.2.
- breaking Starting from v2.8.0, `strict=True` is now set for every call to `zip(...)` internally, which can raise `ValueError` if iterables have different lengths. This change also applies to `sb3_contrib` (v2.6.0).
- gotcha Custom callbacks must return an explicit boolean from `_on_step()` (`True` to continue, `False` to stop training). Since `stable-baselines3-contrib` v2.6.0 (which affects SB3), returning `None` is interpreted as `False` and silently stops training.
- gotcha For accurate evaluation results, especially when other wrappers modify rewards or episode lengths (e.g., reward scaling), it is recommended to wrap your environment with the `Monitor` wrapper before any other wrappers.
Install
- pip install stable-baselines3 gymnasium
- pip install stable-baselines3[extra] gymnasium
Imports
- PPO
from stable_baselines3 import PPO
- A2C
from stable_baselines3 import A2C
- SAC
from stable_baselines3 import SAC
- DQN
from stable_baselines3 import DQN
- make_vec_env
from stable_baselines3.common.env_util import make_vec_env
- evaluate_policy
from stable_baselines3.common.evaluation import evaluate_policy
Quickstart
import gymnasium as gym
from stable_baselines3 import A2C
# Create environment
env = gym.make("CartPole-v1")
# Instantiate the agent
model = A2C("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=10000)
# Save the model
model.save("a2c_cartpole")
# Delete model and reload it to demonstrate saving and loading
del model
model = A2C.load("a2c_cartpole")
# Evaluate the trained agent
obs, info = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()