JAX

0.9.2 · active · verified Thu Mar 26

Composable transformations of Python+NumPy programs: differentiate (jax.grad), compile (jax.jit), vectorize (jax.vmap), parallelize (jax.shard_map). Current version is 0.9.2 (Mar 2026). Requires Python >=3.11. Note that a bare pip install jax gives a minimal CPU-only build; GPU/TPU support requires install extras.

Quickstart

Core JAX pattern: compose grad, jit, vmap freely. All functions must be pure.
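The purity requirement matters because jit traces a function once per input shape and dtype, then caches the compiled result; Python side effects run only during tracing and do not rerun on later calls. A minimal sketch of this pitfall (the counter is illustrative, not part of any JAX API):

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative: counts how often Python-level code runs

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # side effect: executes only while tracing
    return x * 2

double(jnp.arange(3))  # first call: traces and compiles
double(jnp.arange(3))  # same shape/dtype: cached, Python body is skipped
print(trace_count)  # 1, not 2
```

This is why JAX asks for pure functions: anything outside the array computation (prints, mutation, I/O) happens at trace time, not at run time.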

import jax
import jax.numpy as jnp

# grad: automatic differentiation
def loss(params, x):
    return jnp.sum((params['w'] @ x - params['b']) ** 2)

grad_loss = jax.grad(loss)  # gradient w.r.t. first arg by default

# jit: XLA compilation
@jax.jit
def fast_loss(params, x):
    return loss(params, x)

# vmap: auto-vectorize over batch dimension
batched_loss = jax.vmap(loss, in_axes=(None, 0))  # params fixed, x batched

# Compose freely:
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))
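A hedged usage sketch of the composed transform above; the parameter and batch shapes here are illustrative assumptions, not anything mandated by the API:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum((params['w'] @ x - params['b']) ** 2)

# Per-example gradients of a shared set of parameters, compiled with XLA.
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

# Assumed shapes for illustration: w is (3, 4), b is (3,), batch of 5 inputs.
params = {'w': jnp.ones((3, 4)), 'b': jnp.zeros(3)}
xs = jnp.ones((5, 4))

grads = fast_batched_grad(params, xs)
# grads is a pytree matching params, with a leading batch axis of 5:
print(grads['w'].shape)  # (5, 3, 4)
print(grads['b'].shape)  # (5, 3)
```

Because grad differentiates with respect to the first argument by default and in_axes=(None, 0) holds params fixed, the result is one gradient pytree per batch element rather than a gradient of the summed batch loss.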
