{"id":213,"library":"jax","title":"JAX","description":"Composable transformations of Python+NumPy: differentiate (jax.grad), compile (jax.jit), vectorize (jax.vmap), parallelize (jax.shard_map). Current version is 0.9.2 (Mar 2026). Requires Python >=3.11. Install requires extras — bare pip install jax gives CPU-only minimal build.","status":"active","version":"0.9.2","language":"python","source_language":"en","source_url":"https://docs.jax.dev/en/latest/changelog.html","tags":["deep-learning","ml","gpu","tpu","autodiff","jit","numpy","xla"],"install":[{"cmd":"pip install jax[cpu]","lang":"bash","label":"CPU (recommended default)"},{"cmd":"pip install jax[cuda12]","lang":"bash","label":"CUDA 12 GPU"},{"cmd":"pip install jax[cuda13]","lang":"bash","label":"CUDA 13 GPU"},{"cmd":"pip install jax[tpu]","lang":"bash","label":"Google TPU"}],"dependencies":[{"reason":"Binary backend (XLA, PJRT). Must exactly match jax version. Installed automatically with extras.","package":"jaxlib==0.9.2","optional":false},{"reason":"Required. Minimum NumPy version raised to 2.0 in recent releases.","package":"numpy>=2.0","optional":false}],"imports":[{"note":"Since 0.7, jax.jit() requires fun to be passed by position and all other arguments (static_argnums, static_argnames, etc.) by keyword. Positional use of these args now raises an error.","wrong":"# jit now requires fun as positional, all other args as keyword\n# jax.jit(f, (0,))  — positional static_argnums raises DeprecationWarning in 0.6, error in 0.7+","symbol":"jax.jit","correct":"import jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef f(x, y):\n    return jnp.dot(x, y)\n\n# Or explicit keyword args:\njax.jit(f, static_argnums=(0,))"},{"note":"JAX requires pure functions. Side effects (Python print, list append, global mutation) inside jit/grad/vmap are silently dropped or only execute during tracing. 
This is the #1 silent bug for JAX beginners.","wrong":"import jax\n\n# Side effects are silently dropped under jit\nglobal_list = []\n\n@jax.jit\ndef append_and_return(x):\n    global_list.append(x)  # runs only while tracing; appends a tracer, not a value\n    return x * 2","symbol":"pure functions","correct":"import jax\n\n# Pure function: same inputs always give same outputs\n@jax.jit\ndef add(x, y):\n    return x + y"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\n\n# grad: automatic differentiation\ndef loss(params, x):\n    return jnp.sum((params['w'] @ x - params['b']) ** 2)\n\ngrad_loss = jax.grad(loss)  # gradient w.r.t. first arg by default\n\n# jit: XLA compilation\n@jax.jit\ndef fast_loss(params, x):\n    return loss(params, x)\n\n# vmap: auto-vectorize over batch dimension\nbatched_loss = jax.vmap(loss, in_axes=(None, 0))  # params fixed, x batched\n\n# Compose freely:\nfast_batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))","lang":"python","description":"Core JAX pattern: compose grad, jit, vmap freely. All functions must be pure."},"warnings":[{"fix":"Install with the extra that matches your hardware: pip install 'jax[cpu]', pip install 'jax[cuda12]', pip install 'jax[cuda13]', or pip install 'jax[tpu]'. Check https://jax.readthedocs.io/en/latest/installation.html for the current CUDA extras.","message":"Bare pip install jax gives a CPU-only build with no GPU or TPU support. To run on an accelerator you must install an extra, e.g. pip install jax[cuda12] for CUDA 12. Without one, JAX runs on CPU only, which can be far slower for accelerator workloads.","severity":"breaking","affected_versions":"all"},{"fix":"Always install together: pip install 'jax[cpu]' installs the matching jaxlib automatically. If pinning: pin both jax==X.Y.Z and jaxlib==X.Y.Z to the same version.","message":"jax and jaxlib versions must exactly match. Installing mismatched versions raises RuntimeError on import. 
The JAX team periodically yanks old jaxlib wheels from PyPI.","severity":"breaking","affected_versions":"all"},{"fix":"Pass all transform arguments by keyword: jax.jit(f, static_argnums=(0,)) not jax.jit(f, (0,)).","message":"jax.jit() and other transforms now enforce keyword-only arguments (since 0.7). jax.jit(f, (0,)) for static_argnums raises TypeError. Was DeprecationWarning in 0.6.","severity":"breaking","affected_versions":">= 0.7"},{"fix":"Migrate new multi-device code to jax.shard_map. Existing pmap code will continue to work for now but will not receive new features.","message":"jax.pmap is in maintenance mode. New code should use jax.shard_map for multi-device parallelism. pmap's default implementation is being switched to shard_map internals.","severity":"breaking","affected_versions":">= 0.8"},{"fix":"Install older versions from the JAX archive index: pip install 'jax[cpu]==X.Y.Z' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/","message":"Older jaxlib wheels are periodically deleted from PyPI due to storage limits. Pinning old jaxlib versions will cause install failures in CI after deletion.","severity":"breaking","affected_versions":"all"},{"fix":"Use jax.debug.print() for debugging inside jit. Keep all side effects outside of transformed functions. Functions must be pure (same inputs → same outputs).","message":"Side effects inside jit/grad/vmap are silently dropped. Python print(), list.append(), global variable mutations only execute during the initial tracing pass — not on subsequent JIT-compiled calls. The #1 invisible bug for JAX beginners.","severity":"gotcha","affected_versions":"all"},{"fix":"Use the .at[].set() / .at[].add() / .at[].mul() functional update API: x = x.at[0].set(1)","message":"JAX arrays are immutable — no in-place operations. x[0] = 1 raises TypeError. 
Use x.at[0].set(1) for functional updates.","severity":"gotcha","affected_versions":"all"},{"fix":"Enable 64-bit: jax.config.update('jax_enable_x64', True). This must be called before any JAX operations. Or use the context manager: with jax.enable_x64(): ...","message":"By default JAX uses 32-bit floats even when NumPy would use 64-bit. jnp.array(1.0).dtype is float32, not float64. Enable x64 mode explicitly if needed.","severity":"gotcha","affected_versions":"all"}],"env_vars":null,"last_verified":"2026-05-12T11:01:55.954Z","next_check":"2026-06-26T00:00:00.000Z","problems":[{"fix":"Install JAX into the active environment for your specific hardware (CPU, CUDA, ROCm) and Python version, and confirm that `pip` and `python` refer to the same environment (for example by running `python -m pip`). For CPU only: `pip install --upgrade \"jax[cpu]\"`. For CUDA 12: `pip install --upgrade \"jax[cuda12]\"` (or `\"jax[cuda12-local]\"` if you use a locally installed CUDA toolkit).","cause":"The `jax` package is not installed, or not importable, in the Python environment that is running the code. This often happens when the package was installed into a different interpreter or virtual environment than the one in use.","error":"ModuleNotFoundError: No module named 'jax'"},{"fix":"Mark the problematic argument as static using `static_argnums` or `static_argnames` in `jax.jit`, or refactor the code to use JAX's structured control flow primitives like `jax.lax.cond` or `jax.lax.fori_loop` instead of native Python control flow for traced values.","cause":"This error occurs when a JAX Tracer object (an abstract value used during JIT compilation) is used in a context where a concrete, Python-native value is required, often within Python control flow (e.g., `if` statements, loop bounds) inside a `jax.jit`-decorated function. JAX's JIT compilation requires shapes and types to be static, meaning they are known at compile time.","error":"ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected"},{"fix":"Uninstall both `jax` and `jaxlib` completely, then reinstall compatible versions. 
It's often best to install the latest versions together, for example: `pip uninstall jax jaxlib` followed by `pip install --upgrade \"jax[cpu]\"` (or the appropriate GPU/TPU variant).","cause":"This error typically indicates a mismatch between the installed `jax` and `jaxlib` versions, or that an older version of JAX is being used with code expecting a newer API. This specific attribute might also be missing due to circular import issues or a corrupted installation.","error":"AttributeError: module 'jax' has no attribute 'version'"},{"fix":"Either convert the static argument to a hashable type (e.g., a tuple instead of a list), or stop marking it as static so that JAX traces it as a regular array or pytree argument. Alternatively, bind the value ahead of time with `functools.partial(func, arg=value)` and pass the resulting function to `jax.jit`. Note that `jax.vmap` has no static-argument options.","cause":"This error occurs when an argument marked as static via `static_argnums` or `static_argnames` in `jax.jit` is a non-hashable Python object (such as a list, a dictionary, or an array). JAX uses static arguments as part of the cache key for JIT-compiled functions, so they must be hashable.","error":"ValueError: Non-hashable static arguments are not supported"},{"fix":"Verify that your NVIDIA driver, CUDA Toolkit, and cuDNN versions are compatible with the specific `jaxlib` wheel you've installed by consulting the official JAX installation guide. If you rely on a locally installed CUDA toolkit, ensure the `LD_LIBRARY_PATH` and `PATH` environment variables correctly point to it. If using a specific CUDA version, reinstall JAX using the corresponding `jax[cudaXX]` or `jax[cudaXX-local]` extra.","cause":"JAX cannot detect or properly initialize a CUDA-enabled GPU. 
This is often due to incompatible NVIDIA drivers, CUDA Toolkit, or cuDNN versions with the installed `jaxlib` version, or incorrect environment variable settings (e.g., `LD_LIBRARY_PATH`).","error":"RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices."}],"ecosystem":"pypi","meta_description":null,"install_score":0,"install_tag":"stale","quickstart_score":30,"quickstart_tag":"draft","pypi_latest":null,"install_checks":{"last_tested":"2026-05-12","tag":"stale","tag_description":"widespread failures or data too old to trust","results":[{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"cpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"cuda12","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"cuda13","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-alpine","python_version":"3.10","os_libc":"alpine (musl)","variant":"tpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"cpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":0.86,"mem_mb":31.3,"disk_size":"584M"},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim 
(glibc)","variant":"cuda12","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":1.15,"mem_mb":31.4,"disk_size":"5.1G"},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"cuda13","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":0.9,"mem_mb":31.3,"disk_size":"584M"},{"runtime":"python:3.10-slim","python_version":"3.10","os_libc":"slim (glibc)","variant":"tpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":0.88,"mem_mb":31.9,"disk_size":"929M"},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"cpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"cuda12","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"cuda13","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-alpine","python_version":"3.11","os_libc":"alpine (musl)","variant":"tpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"cpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.18,"mem_mb":39.6,"disk_size":"620M"},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim 
(glibc)","variant":"cuda12","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":3.14,"mem_mb":39.9,"disk_size":"5.2G"},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"cuda13","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.79,"mem_mb":39.9,"disk_size":"3.9G"},{"runtime":"python:3.11-slim","python_version":"3.11","os_libc":"slim (glibc)","variant":"tpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.57,"mem_mb":39.8,"disk_size":"1.4G"},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"cpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"cuda12","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"cuda13","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-alpine","python_version":"3.12","os_libc":"alpine (musl)","variant":"tpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"cpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.19,"mem_mb":38.5,"disk_size":"605M"},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim 
(glibc)","variant":"cuda12","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.86,"mem_mb":39.1,"disk_size":"5.2G"},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"cuda13","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.92,"mem_mb":39.1,"disk_size":"3.9G"},{"runtime":"python:3.12-slim","python_version":"3.12","os_libc":"slim (glibc)","variant":"tpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.37,"mem_mb":39,"disk_size":"1.4G"},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"cpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"cuda12","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"cuda13","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-alpine","python_version":"3.13","os_libc":"alpine (musl)","variant":"tpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"cpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.12,"mem_mb":39.7,"disk_size":"604M"},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim 
(glibc)","variant":"cuda12","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.59,"mem_mb":40.2,"disk_size":"5.2G"},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"cuda13","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.47,"mem_mb":40.2,"disk_size":"3.9G"},{"runtime":"python:3.13-slim","python_version":"3.13","os_libc":"slim (glibc)","variant":"tpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":2.35,"mem_mb":40.1,"disk_size":"1.4G"},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"cpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"cuda12","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"cuda13","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-alpine","python_version":"3.9","os_libc":"alpine (musl)","variant":"tpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"cpu","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":1.09,"mem_mb":34.5,"disk_size":"555M"},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim 
(glibc)","variant":"cuda12","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":1.44,"mem_mb":34.5,"disk_size":"4.8G"},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"cuda13","exit_code":0,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":1.11,"mem_mb":34.5,"disk_size":"555M"},{"runtime":"python:3.9-slim","python_version":"3.9","os_libc":"slim (glibc)","variant":"tpu","exit_code":1,"wheel_type":null,"failure_reason":null,"install_time_s":null,"import_time_s":null,"mem_mb":null,"disk_size":null}]},"quickstart_checks":{"last_tested":"2026-04-23","tag":"draft","tag_description":"notable failures across runtimes","results":[{"runtime":"python:3.10-alpine","exit_code":1},{"runtime":"python:3.10-slim","exit_code":0},{"runtime":"python:3.11-alpine","exit_code":1},{"runtime":"python:3.11-slim","exit_code":0},{"runtime":"python:3.12-alpine","exit_code":-1},{"runtime":"python:3.12-slim","exit_code":0},{"runtime":"python:3.13-alpine","exit_code":-1},{"runtime":"python:3.13-slim","exit_code":0},{"runtime":"python:3.9-alpine","exit_code":1},{"runtime":"python:3.9-slim","exit_code":0}]}}