{"id":9659,"library":"distrax","title":"Distrax","description":"Distrax is a lightweight DeepMind library of probability distributions and bijectors, built natively on JAX for high-performance numerical computation, automatic differentiation, and GPU acceleration. It provides a flexible API for constructing complex probabilistic models and is widely used within the JAX ecosystem for research and development. New releases are made periodically to keep pace with JAX and to add features and bug fixes. Current version is 0.1.7.","status":"active","version":"0.1.7","language":"en","source_language":"en","source_url":"https://github.com/deepmind/distrax","tags":["JAX","probability","distributions","machine-learning","deepmind","probabilistic-programming"],"install":[{"cmd":"pip install distrax","lang":"bash","label":"Install distrax"}],"dependencies":[{"reason":"Core numerical backend for distributions and bijectors.","package":"jax","optional":false},{"reason":"JAX's compiled core, required for JAX operations.","package":"jaxlib","optional":false},{"reason":"Used internally and for array interoperability.","package":"numpy","optional":false},{"reason":"Distrax builds on the JAX substrate of TensorFlow Probability internally.","package":"tensorflow-probability","optional":false},{"reason":"DeepMind utilities for testing and assertions in JAX code.","package":"chex","optional":false}],"imports":[{"symbol":"Categorical","correct":"from distrax import Categorical"},{"symbol":"Normal","correct":"from distrax import Normal"},{"symbol":"Distribution","correct":"from distrax import Distribution"},{"symbol":"Transformed","correct":"from distrax import Transformed"},{"symbol":"Bijector","correct":"from distrax import Bijector"}],"quickstart":{"code":"import distrax\nimport jax\nimport jax.numpy as jnp\n\n# Distrax sampling requires an explicit JAX PRNG key; split it so that\n# each independent draw gets its own key\nkey = jax.random.PRNGKey(0)\nkey_cat, key_normal = jax.random.split(key)\n\n# Create a Categorical distribution\nprobs = jnp.array([0.1, 0.2, 0.7])\ncategorical = distrax.Categorical(probs=probs)\n\n# Sample from it (requires a JAX PRNG key)\nsample = categorical.sample(seed=key_cat)\nprint(f\"Categorical sample: {sample}\")\n\n# Compute the log-probability of the sample\nlog_prob = categorical.log_prob(sample)\nprint(f\"Categorical log-prob: {log_prob}\")\n\n# Create a Normal distribution\nloc = jnp.array(0.0)\nscale = jnp.array(1.0)\nnormal = distrax.Normal(loc=loc, scale=scale)\n\n# Sample from it (requires a JAX PRNG key; sample_shape is optional)\nsample_normal = normal.sample(seed=key_normal, sample_shape=(5,))\nprint(f\"Normal samples: {sample_normal}\")\n\n# Compute the log-probability of a specific value\nlog_prob_normal = normal.log_prob(jnp.array(0.5))\nprint(f\"Normal log-prob of 0.5: {log_prob_normal}\")","lang":"python","description":"This example shows how to define common distributions such as Categorical and Normal, sample from them, and compute their log-probabilities with Distrax and JAX. It highlights that sampling requires a JAX PRNG key and that independent draws should use separate keys obtained with jax.random.split."},"warnings":[{"fix":"If you encounter dimension-related errors, pass `shift` and `scale` explicitly as `jax.numpy` arrays (scalars or correctly shaped arrays) rather than Python floats.","message":"In version 0.1.3, the `ScalarAffine` bijector changed its expectations for the `shift` and `scale` parameters: they must now be arrays that are either scalars or broadcast correctly against the event dimensions.","severity":"breaking","affected_versions":">=0.1.3"},{"fix":"Use `distrax.Independent(distribution, reinterpreted_batch_ndims=n)` to reinterpret the rightmost `n` batch dimensions of a distribution as event dimensions.","message":"The `BatchReinterpreted` distribution was deprecated in version 0.1.5 and subsequently removed. Attempting to use it in newer versions results in an `AttributeError`.","severity":"deprecated","affected_versions":">=0.1.5 (removal)"},{"fix":"Generate a JAX PRNG key (e.g., `key = jax.random.PRNGKey(0)`) and pass it as `seed=key` to sampling methods. Split the key with `jax.random.split` before drawing independent samples.","message":"All sampling methods (`sample`, `sample_and_log_prob`) require the keyword-only `seed` argument, which takes a JAX PRNG key (an integer seed is also accepted and converted to a key internally). Omitting it raises an error.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Convert distribution parameters to JAX arrays explicitly, e.g., `loc=jnp.array(0.0)`.","message":"Distribution parameters (e.g., `loc`, `scale`, `probs`) should preferably be JAX arrays (`jax.numpy.array`). Python floats or integers are often converted automatically, but passing explicit arrays makes dtype and shape unambiguous and helps prevent `TypeError`s or unexpected broadcasting.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-17T00:00:00.000Z","next_check":"2026-07-16T00:00:00.000Z","problems":[{"fix":"Convert parameters to `jax.numpy.array` explicitly, e.g., `loc = jnp.array(0.0)`.","cause":"Passing a standard Python float or integer directly to a distribution parameter instead of a JAX array.","error":"TypeError: Invalid type for distribution parameter. Expected `jax.Array` or a type convertible to `jax.Array`, but got `float`."},{"fix":"Generate a JAX PRNG key (`key = jax.random.PRNGKey(0)`) and pass it as `seed=key` to the sampling method.","cause":"Calling `distribution.sample()` or `sample_and_log_prob()` without providing a JAX PRNG key via the `seed` argument.","error":"ValueError: sample requires a PRNG key."},{"fix":"Use `distrax.Independent(distribution, reinterpreted_batch_ndims=n)` to reinterpret batch dimensions as event dimensions.","cause":"Attempting to use the `BatchReinterpreted` distribution, which was deprecated in Distrax 0.1.5 and has since been removed.","error":"AttributeError: module 'distrax' has no attribute 'BatchReinterpreted'"}]}