JAX
v0.9.2 · verified Tue May 12 · auth: no · python install: stale · quickstart: draft
Composable transformations of Python+NumPy: differentiate (jax.grad), compile (jax.jit), vectorize (jax.vmap), parallelize (jax.shard_map). Current version is 0.9.2 (Mar 2026). Requires Python >=3.11. Install requires extras — bare pip install jax gives CPU-only minimal build.
pip install jax[cpu]
Common errors
error ModuleNotFoundError: No module named 'jax' ↓
cause The JAX library, or its essential `jaxlib` component, is not installed or not accessible in the Python environment being used. This often happens if only `pip install jax` was run without the necessary `jaxlib` component or if there are multiple Python environments.
fix Ensure both jax and jaxlib are correctly installed for your specific hardware (CPU, CUDA, ROCm) and Python version. For CPU only: pip install --upgrade "jax[cpu]". For CUDA 12: pip install --upgrade "jax[cuda12-local]".
error ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected ↓
cause This error occurs when a JAX Tracer object (an abstract value used during JIT compilation) is used in a context where a concrete, Python-native value is required, often within Python control flow (e.g., `if` statements, loop bounds) inside a `jax.jit` decorated function. JAX's JIT compilation requires shapes and types to be static, meaning they are known at compile time.
fix Mark the problematic argument as static using static_argnums or static_argnames in jax.jit, or refactor the code to use JAX's structured control flow primitives such as jax.lax.cond or jax.lax.fori_loop instead of native Python control flow on traced values.
error AttributeError: module 'jax' has no attribute 'version' ↓
cause This error typically indicates a mismatch between the installed `jax` and `jaxlib` versions, or that an older version of JAX is being used with code expecting a newer API. This specific attribute might also be missing due to circular import issues or a corrupted installation.
fix Uninstall both jax and jaxlib completely, then reinstall compatible versions. It is often best to install the latest versions together, for example: pip uninstall jax jaxlib followed by pip install --upgrade "jax[cpu]" (or the appropriate GPU/TPU variant).
error ValueError: Non-hashable static arguments are not supported ↓
cause This error occurs when a non-hashable Python object (such as a list or dictionary) is passed in an argument position that `jax.jit` or `jax.vmap` treats as static. JAX hashes static arguments to cache JIT-compiled functions, so every static argument must be hashable.
fix Either convert the non-hashable argument to a hashable type (e.g., a tuple instead of a list), or keep it out of the traced signature entirely, for example by binding it before compilation with functools.partial(func, my_non_hashable_arg) and then applying jax.jit, so that static_argnums or static_argnames only ever name hashable values.
error RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. ↓
cause JAX cannot detect or properly initialize a CUDA-enabled GPU. This is often due to incompatible NVIDIA drivers, CUDA Toolkit, or cuDNN versions with the installed `jaxlib` version, or incorrect environment variable settings (e.g., `LD_LIBRARY_PATH`).
fix Verify that your NVIDIA driver, CUDA Toolkit, and cuDNN versions are compatible with the specific jaxlib wheel you've installed by consulting the official JAX installation guide. Ensure the LD_LIBRARY_PATH and PATH environment variables correctly point to your CUDA installation. If using a specific CUDA version, reinstall JAX using the corresponding jax[cudaXX-local] or jax[cudaXX-pip] extra.
Warnings
breaking bare pip install jax installs a minimal stub — no jaxlib binary. You must use an extra: pip install jax[cpu] for CPU, pip install jax[cuda12] for CUDA 12. Without an extra, import jax raises ImportError or runs extremely slowly. ↓
fix Always install with an extra: pip install 'jax[cpu]' or pip install 'jax[cuda12]'. Check https://jax.readthedocs.io/en/latest/installation.html for the current CUDA extras.
breaking jax and jaxlib versions must exactly match. Installing mismatched versions raises RuntimeError on import. The JAX team periodically yanks old jaxlib wheels from PyPI. ↓
fix Always install together: pip install 'jax[cpu]' — this installs the matching jaxlib automatically. If pinning: pin both jax==X.Y.Z and jaxlib==X.Y.Z to the same version.
breaking jax.jit() and other transforms now enforce keyword-only arguments (since 0.7). jax.jit(f, (0,)) for static_argnums raises TypeError. Was DeprecationWarning in 0.6. ↓
fix Pass all transform arguments by keyword: jax.jit(f, static_argnums=(0,)) not jax.jit(f, (0,)).
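A minimal sketch of the keyword-only rule, using a hypothetical scale function (the name and the static second argument are illustrative, not from the source):

```python
import jax
import jax.numpy as jnp

def scale(x, factor):
    return x * factor

# All options after `fun` must be passed by keyword;
# the positional form jax.jit(scale, (1,)) raises TypeError on recent versions.
scale_jit = jax.jit(scale, static_argnums=(1,))

print(scale_jit(jnp.arange(3.0), 2))  # [0. 2. 4.]
```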
breaking jax.pmap is in maintenance mode. New code should use jax.shard_map for multi-device parallelism. pmap's default implementation is being switched to shard_map internals. ↓
fix Migrate new multi-device code to jax.shard_map. Existing pmap code will continue to work for now but will not receive new features.
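A minimal shard_map sketch that also runs on a single CPU device. The import location has moved between releases, so this hedges with a fallback; the mesh axis name and the doubling function are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                          # newer releases
except ImportError:
    from jax.experimental.shard_map import shard_map   # older releases

# 1-D mesh over whatever devices are available (works with one CPU device)
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("i",))

def double(block):
    # runs per device on its local shard of the input
    return block * 2

doubled = shard_map(double, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
x = jnp.arange(len(devices) * 4.0)
print(doubled(x))
```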
breaking Older jaxlib wheels are periodically deleted from PyPI due to storage limits. Pinning old jaxlib versions will cause install failures in CI after deletion. ↓
fix Install older versions from the JAX archive index: pip install 'jax[cpu]==X.Y.Z' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
gotcha Side effects inside jit/grad/vmap are silently dropped. Python print(), list.append(), global variable mutations only execute during the initial tracing pass — not on subsequent JIT-compiled calls. The #1 invisible bug for JAX beginners. ↓
fix Use jax.debug.print() for debugging inside jit. Keep all side effects outside of transformed functions. Functions must be pure (same inputs → same outputs).
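A small demonstration of both points, assuming a call counter list (illustrative) to make the tracing pass visible: the Python append fires only while tracing, while jax.debug.print fires on every call:

```python
import jax
import jax.numpy as jnp

calls = []

@jax.jit
def double(x):
    calls.append("traced")        # Python side effect: runs only during tracing
    jax.debug.print("x = {}", x)  # runs on every call, even under jit
    return x * 2

double(jnp.arange(3))
double(jnp.arange(3) + 10)  # same shape/dtype: cached, no re-trace
print(len(calls))           # 1, not 2
```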
gotcha JAX arrays are immutable — no in-place operations. x[0] = 1 raises TypeError. Use x.at[0].set(1) for functional updates. ↓
fix Use the .at[].set() / .at[].add() / .at[].mul() functional update API: x = x.at[0].set(1)
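The functional update API from the fix above, as a runnable sketch:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
# x[0] = 1.0 would raise TypeError: JAX arrays are immutable
x = x.at[0].set(1.0)    # returns a new array with index 0 replaced
x = x.at[1:3].add(5.0)  # slices work too
print(x)                # [1. 5. 5. 0.]
```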
gotcha By default JAX uses 32-bit floats even when NumPy would use 64-bit. jnp.array(1.0).dtype is float32, not float64. Enable x64 mode explicitly if needed. ↓
fix Enable 64-bit: jax.config.update('jax_enable_x64', True) — must be called before any JAX operations. Or use context manager: with jax.enable_x64(): ...
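Checking the default-dtype behavior directly, with the config update placed before any array is created, per the gotcha above:

```python
import jax

# Must run before any JAX operation creates an array
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

print(jnp.array(1.0).dtype)  # float64 (would be float32 by default)
```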
Install
pip install jax[cuda12]
pip install jax[cuda13]
pip install jax[tpu]
Install compatibility (stale, last tested: 2026-05-12)
| python | os / libc     | variant | wheel | install | import | disk |
|--------|---------------|---------|-------|---------|--------|------|
| 3.9    | alpine (musl) | cpu     | -     | -       | -      | -    |
| 3.9    | alpine (musl) | cuda12  | -     | -       | -      | -    |
| 3.9    | alpine (musl) | cuda13  | -     | -       | -      | -    |
| 3.9    | alpine (musl) | tpu     | -     | -       | -      | -    |
| 3.9    | slim (glibc)  | cpu     | -     | -       | 1.09s  | 555M |
| 3.9    | slim (glibc)  | cuda12  | -     | -       | 1.44s  | 4.8G |
| 3.9    | slim (glibc)  | cuda13  | -     | -       | 1.11s  | 555M |
| 3.9    | slim (glibc)  | tpu     | -     | -       | -      | -    |
| 3.10   | alpine (musl) | cpu     | -     | -       | -      | -    |
| 3.10   | alpine (musl) | cuda12  | -     | -       | -      | -    |
| 3.10   | alpine (musl) | cuda13  | -     | -       | -      | -    |
| 3.10   | alpine (musl) | tpu     | -     | -       | -      | -    |
| 3.10   | slim (glibc)  | cpu     | -     | -       | 0.86s  | 584M |
| 3.10   | slim (glibc)  | cuda12  | -     | -       | 1.15s  | 5.1G |
| 3.10   | slim (glibc)  | cuda13  | -     | -       | 0.90s  | 584M |
| 3.10   | slim (glibc)  | tpu     | -     | -       | 0.88s  | 929M |
| 3.11   | alpine (musl) | cpu     | -     | -       | -      | -    |
| 3.11   | alpine (musl) | cuda12  | -     | -       | -      | -    |
| 3.11   | alpine (musl) | cuda13  | -     | -       | -      | -    |
| 3.11   | alpine (musl) | tpu     | -     | -       | -      | -    |
| 3.11   | slim (glibc)  | cpu     | -     | -       | 2.18s  | 620M |
| 3.11   | slim (glibc)  | cuda12  | -     | -       | 3.14s  | 5.2G |
| 3.11   | slim (glibc)  | cuda13  | -     | -       | 2.79s  | 3.9G |
| 3.11   | slim (glibc)  | tpu     | -     | -       | 2.57s  | 1.4G |
| 3.12   | alpine (musl) | cpu     | -     | -       | -      | -    |
| 3.12   | alpine (musl) | cuda12  | -     | -       | -      | -    |
| 3.12   | alpine (musl) | cuda13  | -     | -       | -      | -    |
| 3.12   | alpine (musl) | tpu     | -     | -       | -      | -    |
| 3.12   | slim (glibc)  | cpu     | -     | -       | 2.19s  | 605M |
| 3.12   | slim (glibc)  | cuda12  | -     | -       | 2.86s  | 5.2G |
| 3.12   | slim (glibc)  | cuda13  | -     | -       | 2.92s  | 3.9G |
| 3.12   | slim (glibc)  | tpu     | -     | -       | 2.37s  | 1.4G |
| 3.13   | alpine (musl) | cpu     | -     | -       | -      | -    |
| 3.13   | alpine (musl) | cuda12  | -     | -       | -      | -    |
| 3.13   | alpine (musl) | cuda13  | -     | -       | -      | -    |
| 3.13   | alpine (musl) | tpu     | -     | -       | -      | -    |
| 3.13   | slim (glibc)  | cpu     | -     | -       | 2.12s  | 604M |
| 3.13   | slim (glibc)  | cuda12  | -     | -       | 2.59s  | 5.2G |
| 3.13   | slim (glibc)  | cuda13  | -     | -       | 2.47s  | 3.9G |
| 3.13   | slim (glibc)  | tpu     | -     | -       | 2.35s  | 1.4G |
Imports
- jax.jit
wrong
# jit takes fun as positional; all other arguments are keyword-only
# jax.jit(f, (0,)) (positional static_argnums) raised DeprecationWarning in 0.6, raises an error in 0.7+
correct
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    return jnp.dot(x, y)

# Keyword form, for a function whose marked argument is static and hashable:
# jax.jit(g, static_argnums=(0,))
- pure functions
wrong
# Side effects are silently dropped under jit
global_list = []

@jax.jit
def append_and_return(x):
    global_list.append(x)  # appends a tracer during the first trace, then never runs again
    return x * 2
correct
# Pure function: same inputs always give same outputs
@jax.jit
def add(x, y):
    return x + y
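The ConcretizationTypeError fix from Common errors (structured control flow instead of a Python if on a traced value) can be sketched as follows; the my_abs function is illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def my_abs(x):
    # `if x > 0:` here would raise ConcretizationTypeError, because x is a tracer;
    # jax.lax.cond evaluates the predicate inside the compiled computation instead
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

print(my_abs(jnp.float32(-3.0)))  # 3.0
```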
Quickstart (draft, last tested: 2026-04-23)
import jax
import jax.numpy as jnp
# grad: automatic differentiation
def loss(params, x):
return jnp.sum((params['w'] @ x - params['b']) ** 2)
grad_loss = jax.grad(loss) # gradient w.r.t. first arg by default
# jit: XLA compilation
@jax.jit
def fast_loss(params, x):
return loss(params, x)
# vmap: auto-vectorize over batch dimension
batched_loss = jax.vmap(loss, in_axes=(None, 0)) # params fixed, x batched
# Compose freely:
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
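The quickstart pieces can be exercised end to end. The definitions are repeated so the snippet is self-contained, and the parameter shapes (w: 2x3, b: 2, batch of 5) are illustrative:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum((params['w'] @ x - params['b']) ** 2)

grad_loss = jax.grad(loss)
batched_loss = jax.vmap(loss, in_axes=(None, 0))
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(2)}
x = jnp.ones(3)
xs = jnp.ones((5, 3))

print(loss(params, x))                           # 18.0: (w @ x - b) == [3., 3.]
print(grad_loss(params, x)['w'].shape)           # (2, 3): gradients mirror params' pytree
print(batched_loss(params, xs).shape)            # (5,): one loss per batch row
print(fast_batched_grad(params, xs)['b'].shape)  # (5, 2): batched gradients
```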