Pyro: A Python library for probabilistic modeling and inference

1.9.1 · active · verified Sat Apr 11

Pyro is a flexible, scalable deep probabilistic programming library built on PyTorch. It enables expressive deep probabilistic modeling, unifying modern deep learning and Bayesian inference. Maintained by community contributors, including a team at the Broad Institute, Pyro is under active development with frequent releases.

Warnings

Install

Imports

Quickstart

This quickstart demonstrates a simple Bayesian coin-tossing model using Stochastic Variational Inference (SVI). It defines a probabilistic `model`, a variational `guide`, uses synthetic data, performs inference, and extracts the learned posterior parameters for the coin's bias.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Seed PyTorch's RNG so repeated runs give identical results (optional)
torch.manual_seed(1)

# 1. Define a probabilistic model
def model(data):
    """Coin-flip model: Beta(1, 1) prior on the bias, Bernoulli likelihood."""
    # Draw the coin's success probability from a uniform Beta(1, 1) prior
    bias = pyro.sample("theta", dist.Beta(1.0, 1.0))
    # Condition on every observed flip; the plate vectorizes the likelihood
    with pyro.plate("data_loop", len(data)):
        pyro.sample("obs", dist.Bernoulli(bias), obs=data)

# 2. Define a guide (variational distribution)
def guide(data):
    """Variational family for theta: a Beta with learnable parameters.

    Returns the sampled theta tensor so callers can draw from the
    (approximate) posterior after training.
    """
    # Learnable, positivity-constrained parameters of the approximating Beta
    alpha_q = pyro.param("alpha_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(1.0), constraint=dist.constraints.positive)
    # Fix: return the sample. The original discarded it (implicitly returning
    # None), which made downstream calls like guide(data).item() raise
    # AttributeError. Returning the sample is backward-compatible.
    return pyro.sample("theta", dist.Beta(alpha_q, beta_q))

# 3. Synthetic observations: 8 heads (1.0) followed by 2 tails (0.0)
data = torch.cat([torch.ones(8), torch.zeros(2)])

# 4. Configure the Adam optimizer and the SVI inference engine
adam_opt = Adam({"lr": 0.01})
svi = SVI(model, guide, adam_opt, loss=Trace_ELBO())

# 5. Run inference: take n_steps SVI gradient steps over the data
n_steps = 1000
for step in range(n_steps):
    loss = svi.step(data)
    # Only report the ELBO loss every 100 steps to track convergence
    if step % 100 != 0:
        continue
    print(f"Step {step}: Loss = {loss:.4f}")

# 6. Extract learned parameters
alpha_q_learned = pyro.param("alpha_q").item()
beta_q_learned = pyro.param("beta_q").item()
print(f"\nLearned parameters for theta (Beta distribution): alpha_q={alpha_q_learned:.2f}, beta_q={beta_q_learned:.2f}")

# Example: Sample from the inferred posterior.
# Fix: the original called guide(data).item(), but guide() returns None
# (the pyro.sample result is never returned), so .item() raised
# AttributeError. Draw directly from the learned variational Beta instead.
posterior = dist.Beta(alpha_q_learned, beta_q_learned)
posterior_theta_samples = posterior.sample((1000,))
print(f"\nMean of posterior theta samples: {posterior_theta_samples.mean():.2f}")

view raw JSON →