{"id":2429,"library":"chex","title":"Chex","description":"Chex is a library of utilities for helping to write reliable JAX code. It provides tools for instrumenting code (e.g., assertions, warnings), debugging (e.g., transforming `pmap`s to `vmap`s for single-device debugging), and testing JAX code across various execution contexts (e.g., JIT-compiled vs. non-JIT-compiled). The current version is 0.1.91, and it is actively maintained by Google DeepMind with frequent updates.","status":"active","version":"0.1.91","language":"en","source_language":"en","source_url":"https://github.com/google-deepmind/chex","tags":["JAX","testing","debugging","machine learning","deep learning"],"install":[{"cmd":"pip install chex","lang":"bash","label":"Install latest PyPI release"}],"dependencies":[{"reason":"Runtime dependency","package":"absl-py"},{"reason":"Runtime dependency for type hints","package":"typing_extensions"},{"reason":"Core JAX dependency","package":"jax"},{"reason":"Core JAX dependency","package":"jaxlib"},{"reason":"Numerical operations","package":"numpy"},{"reason":"Functional utilities","package":"toolz"}],"imports":[{"symbol":"chex","correct":"import chex"},{"note":"Chex provides a JAX-compatible dataclass implementation.","wrong":"from dataclasses import dataclass","symbol":"dataclass","correct":"from chex import dataclass"},{"note":"Type hint for JAX array on a device.","symbol":"ArrayDevice","correct":"from chex import ArrayDevice"},{"symbol":"assert_tree_all_finite","correct":"from chex import assert_tree_all_finite"},{"symbol":"chexify","correct":"from chex import chexify"},{"symbol":"variants","correct":"from chex import variants"},{"symbol":"assert_max_traces","correct":"from chex import assert_max_traces"}],"quickstart":{"code":"import chex\nimport jax\nimport jax.numpy as jnp\n\n# Define a JAX-friendly dataclass\n@chex.dataclass\nclass Parameters:\n    x: chex.ArrayDevice\n    y: chex.ArrayDevice\n\n# Create an instance\nparams = 
Parameters(x=jnp.ones((2, 2)), y=jnp.ones((1, 2)))\n\n# Dataclasses can be treated as JAX pytrees\ntransformed_params = jax.tree_util.tree_map(lambda val: 2.0 * val, params)\nprint(f\"Original params: {params.x}\\nTransformed params: {transformed_params.x}\")\n\n# Use an assertion\ndef my_func(val):\n    chex.assert_tree_all_finite(val)\n    return val * 2\n\n# Assertions can be used within jitted functions with chexify\n@chex.chexify\n@jax.jit\ndef jitted_func(val):\n    return my_func(val)\n\n# This will pass\njitted_func(jnp.array([1.0, 2.0]))\n\n# This would raise (if uncommented) because of the NaN value.\n# Chex assertions raise AssertionError.\n# try:\n#     jitted_func(jnp.array([1.0, jnp.nan]))\n#     jitted_func.wait_checks()  # block until async checks have run\n# except AssertionError as e:\n#     print(f\"Caught expected error: {e}\")","lang":"python","description":"This quickstart demonstrates defining a JAX-compatible dataclass using `chex.dataclass`, performing a JAX `tree_map` operation on it, and using `chex.assert_tree_all_finite` within a JIT-compiled function by decorating it with `chex.chexify`."},"warnings":[{"fix":"Always initialize `chex.dataclass` instances using keyword arguments (e.g., `MyParams(x=1, y=2)` instead of `MyParams(1, 2)`).","message":"Chex's `mappable_dataclass` and `dataclass` implementations do not support positional arguments for construction, unlike standard Python dataclasses. Arguments must be provided as keyword arguments, similar to a dictionary constructor.","severity":"breaking","affected_versions":"All versions"},{"fix":"To explicitly check for `None`s in PyTrees, use `chex.assert_tree_no_nones()`.","message":"Chex has transitioned from relying on `dm-tree` to using JAX's native `jax.tree_util` for PyTree operations. 
As a result, `None` values are treated as empty subtrees rather than leaves, so `chex` tree assertions skip them by default.","severity":"breaking","affected_versions":"Versions released after the migration away from `dm-tree` (see the release notes for the exact version)."},{"fix":"After calling a `chexify`'d function, call `.wait_checks()` on the function object (e.g., `jitted_func.wait_checks()`) to ensure all asynchronous assertions have completed and raised any errors.","message":"When using `chex.chexify()` with JIT-compiled functions, assertions run asynchronously by default. Errors may therefore not be raised at the failing call but at a later line or function call. For reliable testing, especially when expecting an assertion to fail, explicitly wait for checks to complete.","severity":"gotcha","affected_versions":"All versions using `chex.chexify()` for async assertions."},{"fix":"Place `@chex.assert_max_traces` below `@jax.jit` so that it wraps the plain Python function and `jax.jit` wraps the result. For example: `@jax.jit @chex.assert_max_traces(n=1) def fn(...)`.","message":"The `chex.assert_max_traces()` decorator expects to wrap a pure Python function, not an already JIT-compiled function. Applying it to a function that has already been decorated with `jax.jit` will likely lead to incorrect trace counts or spurious assertion failures.","severity":"gotcha","affected_versions":"All versions."}],"env_vars":null,"last_verified":"2026-04-10T00:00:00.000Z","next_check":"2026-07-09T00:00:00.000Z"}