Pyro: A Python library for probabilistic modeling and inference
Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch. It enables expressive deep probabilistic modeling, unifying modern deep learning and Bayesian inference. Maintained by community contributors, including a team at the Broad Institute, Pyro is under active development with frequent releases.
Warnings
- breaking Pyro 1.9.0 dropped support for PyTorch 1.x and Python 3.7. Users on older PyTorch or Python versions must upgrade to PyTorch 2.x and Python 3.8+ to use Pyro 1.9.0 and newer.
- breaking Pyro 1.8.1 dropped support for Python 3.6. Users on Python 3.6 must upgrade their Python environment to 3.7 or newer to use Pyro 1.8.1 and subsequent versions.
- gotcha Pyro's compatibility with PyTorch versions can be nuanced and has changed across minor releases. For example, 1.8.5 narrowly required `torch>=2.0`, while 1.8.6 re-enabled support for `torch>=1.11` before 1.9.0 definitively dropped PyTorch 1.x. Always check release notes for specific PyTorch version requirements.
- gotcha When defining models with conditionally independent random variables, avoid explicit Python loops and instead use `pyro.plate` for efficient, vectorized computation, especially with large datasets. Loops can be significantly slower and prevent Pyro's internal optimizations.
- gotcha Markov Chain Monte Carlo (MCMC) algorithms like NUTS (No-U-Turn Sampler) require a 'warm-up' phase. Neglecting or misconfiguring `warmup_steps` can lead to unstable chains and biased posterior samples. The warmup samples are discarded and not used for inference.
Install
- pip install torch
- pip install pyro-ppl
Imports
- pyro
import pyro
- pyro.distributions
import pyro.distributions as dist
- pyro.infer
from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS
- pyro.optim
from pyro.optim import Adam
Quickstart
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# Seed the global RNG for reproducible results (optional)
torch.manual_seed(1)
# 1. Define a probabilistic model
def model(data):
    # Global parameter: probability of success 'theta' for a Bernoulli distribution
    theta = pyro.sample("theta", dist.Beta(1.0, 1.0))  # Prior for theta
    # Observe data inside pyro.plate for vectorized computation
    with pyro.plate("data_loop", len(data)):
        pyro.sample("obs", dist.Bernoulli(theta), obs=data)
# 2. Define a guide (variational distribution)
def guide(data):
    # Learnable parameters of the Beta distribution approximating theta's posterior
    alpha_q = pyro.param("alpha_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    pyro.sample("theta", dist.Beta(alpha_q, beta_q))
# 3. Generate synthetic data (e.g., 8 heads, 2 tails)
data = torch.tensor([1.0]*8 + [0.0]*2)
# 4. Set up an optimizer and SVI
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# 5. Run inference
n_steps = 1000
for step in range(n_steps):
    loss = svi.step(data)
    if step % 100 == 0:
        print(f"Step {step}: Loss = {loss:.4f}")
# 6. Extract learned parameters
alpha_q_learned = pyro.param("alpha_q").item()
beta_q_learned = pyro.param("beta_q").item()
print(f"\nLearned parameters for theta (Beta distribution): alpha_q={alpha_q_learned:.2f}, beta_q={beta_q_learned:.2f}")
# Example: Sample from the inferred (approximate) posterior.
# Note: calling guide(data) directly would return None, since the guide does not
# return its sample; instead, build the learned Beta from the fitted parameters.
posterior = dist.Beta(pyro.param("alpha_q"), pyro.param("beta_q"))
posterior_theta_samples = posterior.sample((1000,))
print(f"\nMean of posterior theta samples: {posterior_theta_samples.mean():.2f}")