NumPyro
NumPyro is a probabilistic programming library that leverages JAX for automatic differentiation, JIT compilation, and GPU/TPU acceleration. It lets users build Bayesian models and run inference through a flexible, composable API inspired by Pyro. NumPyro is currently at version 0.20.1 and maintains a regular release cadence, typically shipping minor versions every one to two months with new features, bug fixes, and performance improvements.
Warnings
- gotcha NumPyro's performance and stability are highly dependent on JAX and JAXlib versions. Incompatible versions can lead to cryptic errors or poor performance.
- gotcha JAX's random number generation is functional and stateless: a `jax.random.PRNGKey` is never mutated, so passing the same key to two sampling calls produces identical 'random' results. Split keys explicitly with `jax.random.split` before each independent use.
- gotcha JAX's JIT compilation (which NumPyro heavily utilizes) requires functions to be 'pure' (no side effects, deterministic output for given inputs, no global state changes). Violating this can prevent compilation or lead to incorrect results.
- breaking In NumPyro 0.18.0, the internal caching mechanism for `plates` within `AutoGuide` was removed. This might affect users who relied on inspecting or manipulating internal `_plates` attributes of custom `AutoGuide` implementations.
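The purity requirement above is easy to trip over. The sketch below (the `counter` global and `impure_fn` are illustrative names, not NumPyro API) shows how a Python side effect inside a `jax.jit`-compiled function runs only once, at trace time, rather than on every call:

```python
import jax

counter = 0

@jax.jit
def impure_fn(x):
    global counter
    counter += 1  # Python side effect: executes during tracing, not per call
    return x + 1

impure_fn(1.0)
impure_fn(2.0)  # same shape/dtype -> cached trace, side effect is skipped
print(counter)  # 1, not 2
```

This is why JIT-compiled model code must avoid mutating global state: the mutation happens unpredictably (once per compilation), not once per invocation.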
Install
- pip install numpyro[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- pip install numpyro[cpu]
Imports
- numpyro
import numpyro
- numpyro.distributions
import numpyro.distributions as dist
- numpyro.infer
from numpyro.infer import MCMC, NUTS
- jax.random.PRNGKey
import jax; key = jax.random.PRNGKey(0); key1, key2 = jax.random.split(key)
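The key-splitting pattern matters in practice: reusing a key reproduces the same draws, while a split subkey yields an independent stream. A minimal demonstration:

```python
import jax

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key reused -> identical draws
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))  # fresh subkey -> different draws
print(bool((a == b).all()))  # True
print(bool((a == c).all()))  # False
```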
Quickstart
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
# Optional: Uncomment to force CPU-only execution
# jax.config.update("jax_platform_name", "cpu")
def model(x, obs=None):
# Prior for intercept
a = numpyro.sample("a", dist.Normal(0, 1))
# Prior for slope
b = numpyro.sample("b", dist.Normal(0, 1))
# Prior for observation noise, must be positive
sigma = numpyro.sample("sigma", dist.HalfCauchy(1))
# Linear model mean
mu = a + b * x
# Likelihood
numpyro.sample("obs", dist.Normal(mu, sigma), obs=obs)
# Generate some dummy data
rng_key_x, rng_key_noise, rng_key_model = jax.random.split(jax.random.PRNGKey(0), 3)
true_a = 0.5
true_b = 2.0
true_sigma = 0.8
N_samples = 100
x_data = jax.random.normal(rng_key_x, (N_samples,))
# Use a distinct key for the noise; reusing rng_key_x would make the noise
# identical to the draws used for x_data
y_data = true_a + true_b * x_data + true_sigma * jax.random.normal(rng_key_noise, (N_samples,))
# MCMC setup
kernel = NUTS(model)
mcmc = MCMC(
kernel,
num_warmup=500,
num_samples=1000,
num_chains=1,
progress_bar=False, # Set to True for interactive use
jit_model_args=True,
)
# Run MCMC
mcmc.run(rng_key_model, x=x_data, obs=y_data)
mcmc.print_summary()
# To get posterior samples:
# samples = mcmc.get_samples()
# print("Sampled parameters:", {k: v.shape for k, v in samples.items()})