Blackjax
Version 1.5 | verified Mon Apr 27 | auth: no | language: python
Blackjax is a fast, flexible Markov chain Monte Carlo (MCMC) sampling library for Python. It is built on JAX, so its samplers can be JIT-compiled and run on CPU, GPU, or TPU. The current version is 1.5, and the project is under active development with frequent releases.
pip install blackjax
Common errors
error ModuleNotFoundError: No module named 'blackjax'
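cause Blackjax is not installed, or it was installed into a different environment (for example, another virtualenv or conda env) than the one running your Python interpreter.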
fix Run `pip install blackjax` in the environment you run Python from. `python -m pip install blackjax` ties the install to that interpreter.
error AttributeError: module 'blackjax' has no attribute 'mcmc'
cause Using an old import path for samplers.
fix Call `blackjax.hmc(...)` or `blackjax.nuts(...)` directly instead of `blackjax.mcmc.hmc(...)`.
Warnings
breaking Blackjax v1.0+ removed the old API using `blackjax.mcmc` sampler constructors (e.g., `blackjax.mcmc.hmc`). Use `blackjax.hmc` directly.
fix Replace `from blackjax.mcmc import hmc` with `import blackjax; kernel = blackjax.hmc(...)`.
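gotcha Blackjax requires JAX. The log-density function is traced and differentiated by JAX, so it must be written with JAX numpy (`jnp`) operations rather than NumPy ones, and all randomness comes from explicit JAX PRNG keys. NumPy arrays passed in as data are converted to JAX arrays, but NumPy calls inside the traced function can fail.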
fix Always use `jax.random.PRNGKey` for randomness and `jnp` for data and log-density code. Split keys with `jax.random.split` instead of reusing the same key.
deprecated The `blackjax.mcmc` submodule is deprecated in favor of top-level sampler functions (e.g., `blackjax.hmc`, `blackjax.nuts`).
fix Use `blackjax.hmc(...)` instead of `blackjax.mcmc.hmc(...)`.
Imports
- blackjax
  wrong: `from blackjax import ...` (not for top-level)
  correct: `import blackjax`
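If `import blackjax` raises `ModuleNotFoundError`, the usual cause is that the package lives in a different environment from the running interpreter. The helper below, `blackjax_status`, is a hypothetical name used only for illustration. It is a minimal sketch that reports the active interpreter path and whether `blackjax` is importable from it. It reads `__version__` through `getattr`, so it does not rely on that attribute existing.

```python
import importlib.util
import sys

def blackjax_status():
    """Hypothetical helper: report the running interpreter and whether
    blackjax can be imported from it (plus its version, if exposed)."""
    status = {"python": sys.executable, "installed": False, "version": None}
    if importlib.util.find_spec("blackjax") is not None:
        import blackjax
        status["installed"] = True
        # Read the version defensively in case the attribute is absent
        status["version"] = getattr(blackjax, "__version__", None)
    return status

print(blackjax_status())
```

If `installed` is `False`, running `python -m pip install blackjax` with the interpreter shown under `python` installs into the right environment.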
Quickstart
import jax
import jax.numpy as jnp
import blackjax
# Define a simple target distribution: unnormalized log-density of a standard 2D Gaussian
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)
# Build the HMC kernel
kernel = blackjax.hmc(logdensity_fn, step_size=0.1, inverse_mass_matrix=jnp.eye(2), num_integration_steps=10)
# Initialize state
key = jax.random.PRNGKey(0)
initial_position = jnp.array([1.0, 1.0])
initial_state = kernel.init(initial_position)
# Sample
key, subkey = jax.random.split(key)
state, info = kernel.step(subkey, initial_state)
print(state.position)