Chex
Chex is a library of utilities that help you write reliable JAX code. It provides tools for instrumenting code (e.g., assertions, warnings), debugging (e.g., transforming `pmap`s to `vmap`s for single-device debugging), and testing JAX code across various execution contexts (e.g., JIT-compiled vs. non-JIT-compiled). The current version is 0.1.91, and it is actively maintained by Google DeepMind with frequent updates.
Warnings
- breaking Chex's `mappable_dataclass` and `dataclass` implementations (the latter is mappable by default) do not support positional arguments for construction, unlike standard Python dataclasses. Fields must be passed as keyword arguments, as with the `dict` constructor.
- breaking Chex has moved from `dm-tree` to JAX's native `jax.tree_util` for PyTree operations. Because `jax.tree_util` treats `None` as an empty subtree rather than a leaf, `chex` tree assertions no longer treat `None` values as distinct leaves by default.
- gotcha When `chex.chexify()` wraps a JIT-compiled function, its checks run asynchronously by default. A failing assertion may therefore not raise at the offending call but at a later line or at the next call. For reliable testing, especially when you expect an assertion to fail, call `chex.block_until_chexify_assertions_complete()` to wait for pending checks.
- gotcha The `chex.assert_max_traces()` decorator expects to wrap a pure Python function, not an already JIT-compiled one, so it must sit below `jax.jit` in the decorator stack. If it wraps a function already decorated with `jax.jit`, its counter runs on every Python call rather than only on each trace, which typically produces spurious assertion failures.
Install
pip install chex
Imports
- chex
import chex
- dataclass
from chex import dataclass
- ArrayDevice
from chex import ArrayDevice
- assert_tree_all_finite
from chex import assert_tree_all_finite
- chexify
from chex import chexify
- variants
from chex import variants
- assert_max_traces
from chex import assert_max_traces
Quickstart
import chex
import jax
import jax.numpy as jnp
# Define a JAX-friendly dataclass
@chex.dataclass
class Parameters:
    x: chex.ArrayDevice
    y: chex.ArrayDevice
# Create an instance (keyword arguments only)
params = Parameters(x=jnp.ones((2, 2)), y=jnp.ones((1, 2)))
# Dataclasses can be treated as JAX pytrees
transformed_params = jax.tree_util.tree_map(lambda val: 2.0 * val, params)
print(f"Original params: {params.x}\nTransformed params: {transformed_params.x}")
# Use a value assertion
def my_func(val):
    chex.assert_tree_all_finite(val)
    return val * 2
# Value assertions can be used within jitted functions with chexify
@chex.chexify
@jax.jit
def jitted_func(val):
    return my_func(val)
# This will pass
jitted_func(jnp.array([1.0, 2.0]))
# This would fail (if uncommented) because of the NaN value. chexify checks
# asynchronously by default, so wait for pending checks inside the try block.
# try:
#     jitted_func(jnp.array([1.0, jnp.nan]))
#     chex.block_until_chexify_assertions_complete()
# except AssertionError as e:
#     print(f"Caught expected error: {e}")