NumPyro

0.20.1 · active · verified Mon Apr 13

NumPyro is a probabilistic programming library that uses JAX for automatic differentiation, JIT compilation, and GPU/TPU acceleration. It provides a flexible, composable API, inspired by Pyro, for building Bayesian models and running inference on them. NumPyro is currently at version 0.20.1 and follows a regular release cadence, typically shipping a new minor version every one to two months with new features, bug fixes, and performance improvements.

Install

pip install numpyro

Imports

import numpyro
import numpyro.distributions as dist

Quickstart

This quickstart demonstrates a basic Bayesian linear regression model using NumPyro with the NUTS sampler. It sets up a simple model, generates synthetic data, performs MCMC inference, and prints a summary of the posterior samples. It highlights proper `jax.random.PRNGKey` handling and passing data to the model.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Optional: Uncomment to force CPU-only execution
# jax.config.update("jax_platform_name", "cpu")

def model(x, obs=None):
    # Prior for intercept
    a = numpyro.sample("a", dist.Normal(0, 1))
    # Prior for slope
    b = numpyro.sample("b", dist.Normal(0, 1))
    # Prior for observation noise, must be positive
    sigma = numpyro.sample("sigma", dist.HalfCauchy(1))

    # Linear model mean
    mu = a + b * x

    # Likelihood
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=obs)

# Generate synthetic data; give each random draw its own key
rng_key_x, rng_key_noise, rng_key_model = jax.random.split(jax.random.PRNGKey(0), 3)
true_a = 0.5
true_b = 2.0
true_sigma = 0.8
N_samples = 100
x_data = jax.random.normal(rng_key_x, (N_samples,))
# Reusing rng_key_x here would make the "noise" identical to x_data
y_data = true_a + true_b * x_data + jax.random.normal(rng_key_noise, (N_samples,)) * true_sigma

# MCMC setup
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
    progress_bar=False, # Set to True for interactive use
    jit_model_args=True,
)

# Run MCMC
mcmc.run(rng_key_model, x=x_data, obs=y_data)
mcmc.print_summary()

# Retrieve posterior samples as a dict mapping site names to arrays
samples = mcmc.get_samples()
print("\nSampled parameters:", {k: v.shape for k, v in samples.items()})
