{"id":21006,"library":"blackjax","title":"Blackjax","description":"Blackjax is a flexible and fast Markov chain Monte Carlo (MCMC) sampling library in Python, built on JAX for GPU/TPU acceleration. The current version is 1.5; the project is under active development with frequent releases.","status":"active","version":"1.5","language":"python","source_language":"en","source_url":"https://github.com/blackjax-devs/blackjax","tags":["mcmc","sampling","jax","bayesian","statistics","gpu"],"install":[{"cmd":"pip install blackjax","lang":"bash","label":"Latest"}],"dependencies":[{"reason":"Core dependency for computation and autograd","package":"jax","optional":false},{"reason":"JAX library for CPU/GPU/TPU backends","package":"jaxlib","optional":false},{"reason":"Type annotations for JAX arrays","package":"jaxtyping","optional":true}],"imports":[{"note":"Top-level package; submodules like blackjax.mcmc are accessed via dot notation.","wrong":"from blackjax import ... (not for top-level)","symbol":"blackjax","correct":"import blackjax"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport blackjax\n\n# Define a simple target distribution (2D Gaussian)\ndef logdensity_fn(x):\n    return -0.5 * jnp.sum(x**2)\n\n# Build the HMC kernel\nkernel = blackjax.hmc(logdensity_fn, step_size=0.1, inverse_mass_matrix=jnp.eye(2), num_integration_steps=10)\n\n# Initialize state\nkey = jax.random.PRNGKey(0)\ninitial_position = jnp.array([1.0, 1.0])\ninitial_state = kernel.init(initial_position)\n\n# Sample\nkey, subkey = jax.random.split(key)\nstate, info = kernel.step(subkey, initial_state)\nprint(state.position)\n","lang":"python","description":"Minimal HMC sampling with Blackjax."},"warnings":[{"fix":"Replace `from blackjax.mcmc import hmc` with `import blackjax; kernel = blackjax.hmc(...)`.","message":"Blackjax v1.0+ removed the old API using `blackjax.mcmc` sampler constructors (e.g., `blackjax.mcmc.hmc`). Use `blackjax.hmc` directly.","severity":"breaking","affected_versions":"<1.0"},{"fix":"Write log-density functions with `jax.numpy` (not `numpy`), pass explicit `jax.random.PRNGKey` keys for randomness, and use `jnp.array` for data.","message":"Blackjax is built on JAX: log-density functions must use JAX numpy (`jnp`) rather than NumPy so they can be traced and differentiated, and all randomness comes from explicit JAX PRNG keys.","severity":"gotcha","affected_versions":"All"},{"fix":"Use `blackjax.hmc(...)` instead of `blackjax.mcmc.hmc(...)`.","message":"The `blackjax.mcmc` submodule is deprecated in favor of top-level sampler functions (e.g., `blackjax.hmc`, `blackjax.nuts`).","severity":"deprecated","affected_versions":">=1.0"}],"env_vars":null,"last_verified":"2026-04-27T00:00:00.000Z","next_check":"2026-07-26T00:00:00.000Z","problems":[{"fix":"Run `pip install blackjax` in the correct environment.","cause":"Blackjax is not installed, or is installed in a different environment from the active Python interpreter.","error":"ModuleNotFoundError: No module named 'blackjax'"},{"fix":"Directly call `blackjax.hmc(...)` or `blackjax.nuts(...)` instead of `blackjax.mcmc.hmc(...)`.","cause":"Using old import path for samplers.","error":"AttributeError: module 'blackjax' has no attribute 'mcmc'"}],"ecosystem":"pypi","meta_description":null,"install_score":null,"install_tag":null,"quickstart_score":null,"quickstart_tag":null}