Blackjax

Version 1.5 | verified Mon Apr 27 | auth: no | python

Blackjax is a Markov chain Monte Carlo (MCMC) sampling library for Python. It is built on JAX, so its samplers can be JIT-compiled and run on CPU, GPU, or TPU. The current version is 1.5; the project is under active development with frequent releases.
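Hardware acceleration comes from JAX. A quick way to see which backend and devices the samplers would run on is to ask JAX directly. This sketch uses only standard JAX calls and does not involve any Blackjax-specific device API:

```python
import jax

# JAX picks its default backend (e.g. "cpu", "gpu", or "tpu") from what is installed
backend = jax.default_backend()
devices = jax.devices()
print("default backend:", backend)
print("devices:", devices)
```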

pip install blackjax
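Most install problems come from pip and Python pointing at different environments. This minimal check uses only the standard library (`importlib.metadata`) to show which interpreter is running and whether Blackjax is visible to it. The helper name `blackjax_version` is hypothetical, chosen for illustration:

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def blackjax_version():
    """Return the installed Blackjax version for this interpreter, or None."""
    try:
        return version("blackjax")
    except PackageNotFoundError:
        return None


# sys.executable is the interpreter actually running this code; install into
# the same one, e.g. with: python -m pip install blackjax
print("interpreter:", sys.executable)
print("blackjax:", blackjax_version() or "not installed in this environment")
```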
error ModuleNotFoundError: No module named 'blackjax'
cause Blackjax not installed or installed in an isolated environment.
fix
Run `pip install blackjax` with the same interpreter that runs your code (for example `python -m pip install blackjax`).
error AttributeError: module 'blackjax' has no attribute 'mcmc'
cause Using an old import path for samplers.
fix
Call `blackjax.hmc(...)` or `blackjax.nuts(...)` directly instead of `blackjax.mcmc.hmc(...)`.
breaking Blackjax v1.0+ removed the old API using `blackjax.mcmc` sampler constructors (e.g., `blackjax.mcmc.hmc`). Use `blackjax.hmc` directly.
fix Replace `from blackjax.mcmc import hmc` with `import blackjax; kernel = blackjax.hmc(...)`.
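gotcha Blackjax is built on JAX: the log-density function must be written with JAX numpy (`jnp`) rather than plain NumPy so that JAX can trace and differentiate it, and all randomness must come from explicit JAX PRNG keys.
fix Always create randomness from `jax.random.PRNGKey`, split the key before each use, and use `jnp.array` for data.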
deprecated The `blackjax.mcmc` submodule is deprecated in favor of top-level sampler functions (e.g., `blackjax.hmc`, `blackjax.nuts`).
fix Use `blackjax.hmc(...)` instead of `blackjax.mcmc.hmc(...)`.

Minimal example: a single HMC step with Blackjax.

import jax
import jax.numpy as jnp
import blackjax

# Target: unnormalized log-density of a standard Gaussian (2D, given the initial position below)
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)

# Build the HMC kernel
kernel = blackjax.hmc(
    logdensity_fn,
    step_size=0.1,
    inverse_mass_matrix=jnp.eye(2),
    num_integration_steps=10,
)

# Initialize state
key = jax.random.PRNGKey(0)
initial_position = jnp.array([1.0, 1.0])
initial_state = kernel.init(initial_position)

# Take one HMC step using a fresh subkey
key, subkey = jax.random.split(key)
state, info = kernel.step(subkey, initial_state)
print(state.position)