Distrax
Distrax is a lightweight DeepMind library of probability distributions and bijectors. It is built on JAX, so it works with JAX's automatic differentiation, JIT compilation, and GPU acceleration. It reimplements a subset of TensorFlow Probability (TFP) in JAX with an emphasis on extensibility, offers a flexible API for constructing probabilistic models, and is widely used in the JAX ecosystem for research and development. Releases generally track JAX for compatibility and add new features and bug fixes. Current version is 0.1.7.
Common errors
- TypeError: Invalid type for distribution parameter. Expected `jax.Array` or a type convertible to `jax.Array`, but got `float`.
  Cause: Passing a standard Python float or integer directly to a distribution parameter instead of a JAX array.
  Fix: Convert parameters to JAX arrays explicitly, e.g., `loc = jnp.array(0.0)`.
- ValueError: sample requires a PRNG key.
  Cause: Calling `distribution.sample()` or `sample_and_log_prob()` without passing a JAX PRNG key via the `seed` argument.
  Fix: Generate a key (`key = jax.random.PRNGKey(0)`) and pass it to the sampling method as `seed=key`.
- AttributeError: module 'distrax' has no attribute 'BatchReinterpreted'
  Cause: Using the `BatchReinterpreted` distribution, which is not available in recent Distrax versions (it was deprecated in 0.1.5 and then removed).
  Fix: Wrap the distribution in `distrax.Independent(distribution, reinterpreted_batch_ndims=n)`, which reinterprets its trailing `n` batch dimensions as event dimensions. Note that `distribution.batch_shape` is a plain shape tuple and has no methods for manipulating dimensions.
Warnings
- breaking The `ScalarAffine` bijector in version 0.1.3 changed its expectation for `shift` and `scale` parameters. They now explicitly expect `Array`s that are scalars or broadcast correctly to event dimensions.
- deprecated The `BatchReinterpreted` distribution was deprecated in version 0.1.5 and subsequently removed. Attempting to use it in newer versions will result in an `AttributeError`.
- gotcha All sampling methods (`sample`, `sample_and_log_prob`) require a JAX PRNG key to be passed via the `seed` argument. Forgetting this will raise an error.
- gotcha Distribution parameters (e.g., `loc`, `scale`, `probs`) should ideally be JAX arrays (`jax.numpy.array`). Python floats and integers are often converted automatically, but explicit conversion is recommended to avoid `TypeError`s and unexpected dtype or broadcasting behavior.
Install
pip install distrax
Imports
- Categorical
from distrax import Categorical
- Normal
from distrax import Normal
- Distribution
from distrax import Distribution
- Transformed
from distrax import Transformed
- Bijector
from distrax import Bijector
Quickstart
import distrax
import jax
import jax.numpy as jnp
# Sampling requires a PRNG key; a fixed seed makes results reproducible
key = jax.random.PRNGKey(0)
# Split the key so each sampling call gets independent randomness
key_cat, key_normal = jax.random.split(key)
# Create a Categorical distribution over three outcomes
probs = jnp.array([0.1, 0.2, 0.7])
categorical = distrax.Categorical(probs=probs)
# Draw one sample (an integer index in {0, 1, 2})
sample = categorical.sample(seed=key_cat)
print(f"Categorical sample: {sample}")
# Compute the log-probability of that sample
log_prob = categorical.log_prob(sample)
print(f"Categorical log-prob: {log_prob}")
# Create a standard Normal distribution
loc = jnp.array(0.0)
scale = jnp.array(1.0)
normal = distrax.Normal(loc=loc, scale=scale)
# Draw 5 samples; sample_shape is prepended to the distribution's batch and event shape
sample_normal = normal.sample(seed=key_normal, sample_shape=(5,))
print(f"Normal samples: {sample_normal}")
# Compute log-probability for a specific value
log_prob_normal = normal.log_prob(jnp.array(0.5))
print(f"Normal log-prob of 0.5: {log_prob_normal}")