JAX
Composable transformations of Python+NumPy programs: differentiate (jax.grad), compile (jax.jit), vectorize (jax.vmap), parallelize (jax.shard_map). Current version is 0.9.2 (Mar 2026). Requires Python >=3.11. Installation requires an extra: bare pip install jax yields only a minimal stub without the jaxlib binary.
Warnings
- breaking bare pip install jax installs a minimal stub with no jaxlib binary. You must use an extra: pip install jax[cpu] for CPU, pip install jax[cuda12] for CUDA 12. Without an extra, import jax either raises ImportError or falls back to an extremely slow path.
- breaking jax and jaxlib versions must exactly match. Installing mismatched versions raises RuntimeError on import. The JAX team periodically yanks old jaxlib wheels from PyPI.
- breaking jax.jit() and other transforms now enforce keyword-only arguments (since 0.7). jax.jit(f, (0,)) for static_argnums raises TypeError. Was DeprecationWarning in 0.6.
- breaking jax.pmap is in maintenance mode. New code should use jax.shard_map for multi-device parallelism. pmap's default implementation is being switched to shard_map internals.
- breaking Older jaxlib wheels are periodically deleted from PyPI due to storage limits. Pinning old jaxlib versions will cause install failures in CI after deletion.
- gotcha Side effects inside jit/grad/vmap are silently dropped. Python print(), list.append(), global variable mutations only execute during the initial tracing pass — not on subsequent JIT-compiled calls. The #1 invisible bug for JAX beginners.
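A minimal sketch of the tracing behavior (the list name trace_log is illustrative):

```python
import jax
import jax.numpy as jnp

trace_log = []  # illustrative side-effect target

@jax.jit
def double(x):
    trace_log.append("traced")  # side effect: runs only during tracing
    return x * 2

double(jnp.arange(3))  # first call with this shape/dtype: traces, append runs
double(jnp.arange(3))  # cached compilation is reused: append does NOT run
# trace_log holds a single entry even though double was called twice
```

Anything that must run on every call (logging, metrics) has to go outside the jitted function or use jax.debug.print.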
- gotcha JAX arrays are immutable — no in-place operations. x[0] = 1 raises TypeError. Use x.at[0].set(1) for functional updates.
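A small sketch of the functional-update style:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1 would raise TypeError: JAX arrays are immutable
y = x.at[0].set(1.0)  # returns a NEW array with index 0 updated
# x is left unchanged; y is [1., 0., 0.]
```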
- gotcha By default JAX uses 32-bit floats even when NumPy would use 64-bit. jnp.array(1.0).dtype is float32, not float64. Enable x64 mode explicitly if needed.
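Checking the default dtype, then switching on x64 mode via jax.config (ideally done at program startup, before any arrays are created):

```python
import jax
import jax.numpy as jnp

default_dtype = jnp.array(1.0).dtype       # float32, not float64
jax.config.update("jax_enable_x64", True)  # enable 64-bit mode
x64_dtype = jnp.array(1.0).dtype           # now float64
```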
Install
- pip install jax[cpu]
- pip install jax[cuda12]
- pip install jax[cuda13]
- pip install jax[tpu]
Imports
- jax.jit
import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    return jnp.dot(x, y)

# static_argnums must be passed as a keyword argument:
g = jax.jit(f, static_argnums=(0,))
- pure functions
# Pure function: same inputs always give same outputs
@jax.jit
def add(x, y):
    return x + y
Quickstart
import jax
import jax.numpy as jnp
# grad: automatic differentiation
def loss(params, x):
    return jnp.sum((params['w'] @ x - params['b']) ** 2)
grad_loss = jax.grad(loss) # gradient w.r.t. first arg by default
# jit: XLA compilation
@jax.jit
def fast_loss(params, x):
    return loss(params, x)
# vmap: auto-vectorize over batch dimension
batched_loss = jax.vmap(loss, in_axes=(None, 0)) # params fixed, x batched
# Compose freely:
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
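Exercising the composed transform end to end; the shapes of params and xs below are made up for illustration:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum((params['w'] @ x - params['b']) ** 2)

# jit(vmap(grad(...))): per-example gradients, batched and compiled
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

params = {'w': jnp.ones((2, 3)), 'b': jnp.ones(2)}  # illustrative shapes
xs = jnp.ones((5, 3))                               # batch of 5 inputs
grads = fast_batched_grad(params, xs)
# grads mirrors the params pytree with a leading batch axis:
# grads['w'].shape == (5, 2, 3), grads['b'].shape == (5, 2)
```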