Distrax

0.1.7 · active · verified Fri Apr 17

Distrax is a DeepMind library of probability distributions and bijectors built on JAX. Its API closely mirrors TensorFlow Probability, and its distributions and log-densities are ordinary JAX functions, so they compose with jax.jit, jax.grad, and jax.vmap for high-performance, differentiable computation on CPU, GPU, and TPU. It is widely used within the JAX ecosystem for building probabilistic models in research code. Releases typically track JAX for compatibility, with frequent updates for new features and bug fixes; the current version is 0.1.7.

Common errors

Warnings

Install

pip install distrax

Imports

import distrax
import jax
import jax.numpy as jnp

Quickstart

This example shows how to define common distributions such as Categorical and Normal, draw samples from them, and compute log-probabilities with Distrax and JAX. Note that, as everywhere in JAX, sampling requires an explicit PRNG key.

import distrax
import jax
import jax.numpy as jnp

# JAX has no global random state: every sampling call needs an explicit PRNG key
key = jax.random.PRNGKey(0)

# Create a Categorical distribution
probs = jnp.array([0.1, 0.2, 0.7])
categorical = distrax.Categorical(probs=probs)

# Sample from it (requires a JAX PRNG key)
sample = categorical.sample(seed=key)
print(f"Categorical sample: {sample}")

# Compute log-probability
log_prob = categorical.log_prob(sample)
print(f"Categorical log-prob: {log_prob}")

# Create a Normal distribution
loc = jnp.array(0.0)
scale = jnp.array(1.0)
normal = distrax.Normal(loc=loc, scale=scale)

# Sample from it; split the key first so this draw is independent of the
# Categorical sample above (reusing a key reproduces the same randomness).
# sample_shape controls how many i.i.d. draws are returned.
key, subkey = jax.random.split(key)
sample_normal = normal.sample(seed=subkey, sample_shape=(5,))
print(f"Normal samples: {sample_normal}")

# Compute log-probability for a specific value
log_prob_normal = normal.log_prob(jnp.array(0.5))
print(f"Normal log-prob of 0.5: {log_prob_normal}")
