JAX
Composable transformations of Python+NumPy: differentiate (jax.grad), compile (jax.jit), vectorize (jax.vmap), parallelize (jax.shard_map). Current version is 0.9.2 (Mar 2026). Requires Python >=3.11. Installation requires extras: a bare pip install jax gives only a minimal CPU build.
Common errors
- ModuleNotFoundError: No module named 'jax'
  cause: JAX, or its essential `jaxlib` component, is not installed or not visible in the active Python environment. This often happens when only a bare `pip install jax` was run, or when multiple Python environments are in play.
  fix: Ensure both `jax` and `jaxlib` are installed for your hardware (CPU, CUDA, ROCm) and Python version. For CPU only: `pip install --upgrade "jax[cpu]"`. For CUDA 12: `pip install --upgrade "jax[cuda12-local]"`.
- ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected
  cause: A JAX Tracer (the abstract value used during JIT tracing) was used where a concrete Python value is required, typically in native Python control flow (`if` statements, loop bounds) inside a `jax.jit`-decorated function. JIT compilation requires shapes and branching decisions to be known at trace time.
  fix: Mark the offending argument as static via `static_argnums` or `static_argnames` in `jax.jit`, or refactor to JAX's structured control flow primitives (`jax.lax.cond`, `jax.lax.fori_loop`) instead of Python control flow on traced values; see the first sketch after this list.
- AttributeError: module 'jax' has no attribute 'version'
  cause: Typically a mismatch between the installed `jax` and `jaxlib` versions, or an older JAX used with code expecting a newer API. It can also result from a circular import (e.g. a local file named `jax.py` shadowing the package) or a corrupted installation.
  fix: Uninstall both packages completely, then reinstall compatible versions together: `pip uninstall jax jaxlib` followed by `pip install --upgrade "jax[cpu]"` (or the appropriate GPU/TPU variant).
- ValueError: Non-hashable static arguments are not supported
  cause: An argument marked static in `jax.jit` (via `static_argnums` or `static_argnames`) is not hashable, e.g. a list, dict, or JAX array. JAX caches compiled functions keyed on the hash of their static arguments, so every static argument must be hashable.
  fix: Convert the argument to a hashable type (e.g. a tuple instead of a list), bake it in with `functools.partial` before applying `jax.jit`, or stop marking it static so it is traced instead; see the second sketch after this list.
- RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
  cause: JAX cannot detect or initialize a CUDA-capable GPU, usually because the NVIDIA driver, CUDA Toolkit, or cuDNN version is incompatible with the installed `jaxlib`, or because environment variables (e.g. `LD_LIBRARY_PATH`) are set incorrectly.
  fix: Check the official JAX installation guide for the driver, CUDA, and cuDNN versions your `jaxlib` wheel requires. Ensure `LD_LIBRARY_PATH` and `PATH` point at your CUDA installation. If targeting a specific CUDA version, reinstall with the corresponding extra, e.g. `jax[cuda12]` (bundled CUDA libraries) or `jax[cuda12-local]` (system CUDA).
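A minimal sketch of the two ConcretizationTypeError fixes above; `relu_cond` and `scale` are illustrative names, not JAX APIs:

import jax
import jax.numpy as jnp
from functools import partial

# Fix 1: jax.lax.cond traces both branches (the predicate must be a scalar)
@jax.jit
def relu_cond(x):
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

# Fix 2: mark a genuinely static argument; Python `if` on it is then fine
@partial(jax.jit, static_argnums=1)
def scale(x, use_double):
    if use_double:  # ok: use_double is a concrete Python bool at trace time
        return x * 2.0
    return x

relu_cond(jnp.asarray(-3.0))   # Array(0., dtype=float32)
scale(jnp.arange(3.0), True)   # Array([0., 2., 4.], dtype=float32)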
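And a short sketch of the hashability fix; `pad_to` is a hypothetical helper whose target shape must be static:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=1)
def pad_to(x, shape):  # shape is part of the compilation cache key
    return jnp.zeros(shape, x.dtype).at[:x.size].set(x)

x = jnp.arange(3.0)
# pad_to(x, [8])   # ValueError: non-hashable static argument (list)
pad_to(x, (8,))    # ok: tuple is hashable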
Warnings
- breaking bare pip install jax installs a minimal stub — no jaxlib binary. You must use an extra: pip install jax[cpu] for CPU, pip install jax[cuda12] for CUDA 12. Without an extra, import jax raises ImportError or runs extremely slowly.
- breaking jax and jaxlib versions must match (each jax release supports only a narrow jaxlib range); mismatched versions raise RuntimeError on import. The JAX team periodically yanks old jaxlib wheels from PyPI.
- breaking jax.jit() and other transforms now enforce keyword-only arguments (since 0.7). jax.jit(f, (0,)) for static_argnums raises TypeError. Was DeprecationWarning in 0.6.
- breaking jax.pmap is in maintenance mode; new code should use jax.shard_map for multi-device parallelism (see the sketch after this list). pmap's default implementation is being switched to shard_map internals.
- breaking Older jaxlib wheels are periodically deleted from PyPI due to storage limits. Pinning old jaxlib versions will cause install failures in CI after deletion.
- gotcha Side effects inside jit/grad/vmap are silently dropped. Python print(), list.append(), global variable mutations only execute during the initial tracing pass — not on subsequent JIT-compiled calls. The #1 invisible bug for JAX beginners.
- gotcha JAX arrays are immutable — no in-place operations. x[0] = 1 raises TypeError. Use x.at[0].set(1) for functional updates.
- gotcha By default JAX uses 32-bit floats even where NumPy would use 64-bit: jnp.array(1.0).dtype is float32, not float64. Enable x64 mode explicitly if needed (see the snippet after this list).
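For the x64 gotcha, flip the flag through the `jax.config` API (or set the `JAX_ENABLE_X64` environment variable) before any arrays are created:

import jax
jax.config.update("jax_enable_x64", True)  # must run before creating arrays

import jax.numpy as jnp
print(jnp.array(1.0).dtype)  # float64 with x64 mode on; float32 otherwise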
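And a minimal shard_map sketch of the pmap replacement mentioned above. It assumes the leading array dimension divides the device count; `device_sum` is an illustrative name, and on a single-host CPU you can simulate extra devices with XLA_FLAGS=--xla_force_host_platform_device_count=8:

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

n = jax.device_count()
mesh = jax.make_mesh((n,), ('i',))  # 1-D mesh over all visible devices

def device_sum(x_shard):
    # x_shard is this device's slice; psum reduces across mesh axis 'i'
    return jax.lax.psum(jnp.sum(x_shard), axis_name='i')

total = jax.shard_map(device_sum, mesh=mesh, in_specs=P('i'), out_specs=P())(
    jnp.arange(4.0 * n))  # replicated scalar: sum of 0..4n-1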
Install
- pip install "jax[cpu]"
- pip install "jax[cuda12]"
- pip install "jax[cuda13]"
- pip install "jax[tpu]"
Imports
- jax.jit
# jit takes fun as the only positional argument; all other args are keyword-only
# jax.jit(f, (0,)) as positional static_argnums: DeprecationWarning in 0.6, TypeError in 0.7+
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    return jnp.dot(x, y)

# Explicit keyword form: jax.jit(f, static_argnums=(0,))
- pure functions
# Side effects are silently dropped under jit
global_list = []

@jax.jit
def append_and_return(x):
    global_list.append(x)  # silently does nothing under jit
    return x * 2

# Pure function: same inputs always give same outputs
@jax.jit
def add(x, y):
    return x + y
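Calling the impure version above twice makes the silent drop visible:

append_and_return(jnp.arange(3))  # first call traces: the append runs once, with a Tracer
append_and_return(jnp.arange(3))  # cached call: the Python body never re-executes
print(len(global_list))           # 1, not 2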
Quickstart
import jax
import jax.numpy as jnp
# grad: automatic differentiation
def loss(params, x):
return jnp.sum((params['w'] @ x - params['b']) ** 2)
grad_loss = jax.grad(loss) # gradient w.r.t. first arg by default
# jit: XLA compilation
@jax.jit
def fast_loss(params, x):
return loss(params, x)
# vmap: auto-vectorize over batch dimension
batched_loss = jax.vmap(loss, in_axes=(None, 0)) # params fixed, x batched
# Compose freely:
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
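# Example usage with hypothetical shapes:
params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros(3)}
xs = jnp.ones((8, 4))
per_example_grads = fast_batched_grad(params, xs)
print(per_example_grads['w'].shape)  # (8, 3, 4): one gradient per batch element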