{"id":5345,"library":"numpyro","title":"NumPyro","description":"NumPyro is a probabilistic programming library that leverages JAX for automatic differentiation, JIT compilation, and GPU/TPU acceleration. It allows users to build and infer Bayesian models with a flexible and composable API inspired by Pyro. NumPyro is currently at version 0.20.1 and maintains a regular release cadence, often releasing minor versions monthly or bi-monthly with new features, bug fixes, and performance improvements.","status":"active","version":"0.20.1","language":"en","source_language":"en","source_url":"https://github.com/pyro-ppl/numpyro","tags":["probabilistic programming","bayesian inference","jax","deep learning","mcmc","variational inference"],"install":[{"cmd":"pip install numpyro[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html","lang":"bash","label":"For CUDA-enabled GPU (change cuda12_pip to your CUDA version)"},{"cmd":"pip install numpyro[cpu]","lang":"bash","label":"For CPU only"}],"dependencies":[{"reason":"Core dependency for automatic differentiation, JIT compilation, and device management.","package":"jax","optional":false},{"reason":"JAX's compiled backend; must be compatible with JAX and your hardware (CPU/GPU/TPU).","package":"jaxlib","optional":false},{"reason":"Used for optimizers in some inference algorithms and contributions.","package":"optax","optional":true},{"reason":"Used internally for some advanced inference and distribution functionalities.","package":"funsor","optional":true}],"imports":[{"symbol":"numpyro","correct":"import numpyro"},{"symbol":"numpyro.distributions","correct":"import numpyro.distributions as dist"},{"symbol":"numpyro.infer","correct":"from numpyro.infer import MCMC, NUTS"},{"note":"JAX PRNG keys are stateless: reusing a key reproduces identical 'random' values, so split it before each use.","wrong":"key = jax.random.PRNGKey(0); result1 = some_func(key); result2 = another_func(key)","symbol":"jax.random.PRNGKey","correct":"import jax; key = 
jax.random.PRNGKey(0); key1, key2 = jax.random.split(key)"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.infer import MCMC, NUTS\n\n# Optional: Uncomment to force CPU-only execution\n# jax.config.update(\"jax_platform_name\", \"cpu\")\n\ndef model(x, obs=None):\n    # Prior for intercept\n    a = numpyro.sample(\"a\", dist.Normal(0, 1))\n    # Prior for slope\n    b = numpyro.sample(\"b\", dist.Normal(0, 1))\n    # Prior for observation noise, must be positive\n    sigma = numpyro.sample(\"sigma\", dist.HalfCauchy(1))\n\n    # Linear model mean\n    mu = a + b * x\n\n    # Likelihood\n    numpyro.sample(\"obs\", dist.Normal(mu, sigma), obs=obs)\n\n# Generate some dummy data; split one key three ways so no key is ever reused\nrng_key_x, rng_key_noise, rng_key_model = jax.random.split(jax.random.PRNGKey(0), 3)\ntrue_a = 0.5\ntrue_b = 2.0\ntrue_sigma = 0.8\nN_samples = 100\nx_data = jax.random.normal(rng_key_x, (N_samples,))\ny_data = true_a + true_b * x_data + jax.random.normal(rng_key_noise, (N_samples,)) * true_sigma\n\n# MCMC setup\nkernel = NUTS(model)\nmcmc = MCMC(\n    kernel,\n    num_warmup=500,\n    num_samples=1000,\n    num_chains=1,\n    progress_bar=False,  # Set to True for interactive use\n    jit_model_args=True,\n)\n\n# Run MCMC\nmcmc.run(rng_key_model, x=x_data, obs=y_data)\nmcmc.print_summary()\n\n# To get posterior samples:\n# samples = mcmc.get_samples()\n# print(\"\\nSampled parameters:\", {k: v.shape for k, v in samples.items()})\n","lang":"python","description":"This quickstart demonstrates a basic Bayesian linear regression model using NumPyro with the NUTS sampler. It sets up a simple model, generates synthetic data, performs MCMC inference, and prints a summary of the posterior samples. 
It highlights proper `jax.random.PRNGKey` handling and passing data to the model."},"warnings":[{"fix":"Always install `jax` and `jaxlib` using the officially recommended method for your hardware (CPU/GPU/TPU) and ensure compatibility with your `numpyro` version. Consult the JAX installation guide and `numpyro`'s dependencies.","message":"NumPyro's performance and stability are highly dependent on the installed `jax` and `jaxlib` versions. Incompatible versions can lead to cryptic errors or poor performance.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Use `jax.random.split(key)` to generate new, independent keys for each random operation or branch in your JAX/NumPyro code. For MCMC, a fresh key should be passed to `mcmc.run()`.","message":"JAX's random number generation is functional: PRNG keys are stateless, and reusing the same key reproduces identical 'random' results. Keys must be explicitly split before each subsequent operation.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Ensure your models are pure functions. Avoid Python control flow (loops, conditionals) that depends on data values; use JAX's `jax.lax` primitives (e.g., `jax.lax.scan`, `jax.lax.cond`) for data-dependent logic inside JIT-compiled functions.","message":"JAX's JIT compilation (which NumPyro relies on heavily) requires functions to be pure: no side effects, deterministic output for given inputs, and no changes to global state. Violating this can prevent compilation or lead to incorrect results.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Refactor custom `AutoGuide` implementations to avoid relying on internal, non-public attributes; define guides through the public API. The `sample_posterior()` signature was also unified, requiring updates to direct calls if not using the standard `mcmc.get_samples()`.","message":"In NumPyro 0.18.0, the internal caching mechanism for `plates` within `AutoGuide` was removed. 
This might affect users who relied on inspecting or manipulating internal `_plates` attributes of custom `AutoGuide` implementations.","severity":"breaking","affected_versions":">=0.18.0"}],"env_vars":null,"last_verified":"2026-04-13T00:00:00.000Z","next_check":"2026-07-12T00:00:00.000Z"}