Chex

0.1.91 · active · verified Fri Apr 10

Chex is a library of utilities for writing reliable JAX code. It provides tools for instrumenting code (e.g., assertions and warnings), debugging (e.g., replacing `pmap` with `vmap` for single-device debugging), and testing JAX code across execution contexts (e.g., with and without JIT compilation). The current version is 0.1.91; it is actively maintained by Google DeepMind and updated frequently.

Quickstart

This quickstart demonstrates defining a JAX-compatible dataclass using `chex.dataclass`, performing a JAX `tree_map` operation on it, and using `chex.assert_tree_all_finite` within a JIT-compiled function by decorating it with `chex.chexify`.

import chex
import jax
import jax.numpy as jnp

# Define a JAX-friendly dataclass
@chex.dataclass
class Parameters:
    x: chex.ArrayDevice
    y: chex.ArrayDevice

# Create an instance
params = Parameters(x=jnp.ones((2, 2)), y=jnp.ones((1, 2)))

# Dataclasses can be treated as JAX pytrees
transformed_params = jax.tree_util.tree_map(lambda val: 2.0 * val, params)
print(f"Original params: {params.x}\nTransformed params: {transformed_params.x}")

# A function that asserts every leaf of its input is finite (a value assertion)
def my_func(val):
    chex.assert_tree_all_finite(val)
    return val * 2

# Assertions can be used within jitted functions with chexify
@chex.chexify
@jax.jit
def jitted_func(val):
    return my_func(val)

# This will pass
jitted_func(jnp.array([1.0, 2.0]))

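# This would fail (if uncommented) because of the NaN value. chexify runs its
# checks asynchronously by default, so the error may surface on a later call or
# when wait_checks() is called; failed chex assertions raise AssertionError.
# try:
#     jitted_func(jnp.array([1.0, jnp.nan]))
#     jitted_func.wait_checks()
# except AssertionError as e:
#     print(f"Caught expected error: {e}")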
