{"library":"jaxlib","title":"jaxlib","description":"jaxlib is the essential support library for JAX, containing the binary (C/C++) parts of the JAX ecosystem, including Python bindings, the XLA compiler, the PJRT runtime, and various handwritten kernels. While JAX itself is a pure Python package providing the high-level API, jaxlib acts as its compiled backend, enabling high-performance numerical computation on CPUs, GPUs, and TPUs. The current version is 0.9.2; jaxlib is released frequently, usually alongside the `jax` release with the same version number.","status":"active","version":"0.9.2","language":"en","source_language":"en","source_url":"https://github.com/jax-ml/jax","tags":["machine learning","deep learning","numerical computation","XLA","JIT","compilation","accelerators","backend"],"install":[{"cmd":"pip install --upgrade pip\npip install --upgrade jax jaxlib","lang":"bash","label":"For CPU-only (Linux/macOS/Windows)"},{"cmd":"pip install --upgrade pip\npip install --upgrade \"jax[cuda12]\"","lang":"bash","label":"For NVIDIA GPU (CUDA 12, pip-installed CUDA/cuDNN)"},{"cmd":"# For specific CUDA versions, locally installed CUDA, or TPUs, refer to the official JAX installation guide:\n# https://jax.readthedocs.io/en/latest/installation.html","lang":"bash","label":"Other accelerator installations"}],"dependencies":[{"reason":"jaxlib is the compiled backend for the `jax` Python frontend; install compatible versions of both packages.","package":"jax"},{"reason":"Required for array operations; NumPy 2.0 or newer is required as of jaxlib 0.7.2.","package":"numpy"},{"reason":"Required for scientific computing functions; SciPy 1.13 or newer is required as of jaxlib 0.7.2, because it is the first SciPy release that supports NumPy 2.0.","package":"scipy"},{"reason":"Provides machine-learning-specific NumPy dtypes such as bfloat16 and the float8 variants.","package":"ml_dtypes"}],"imports":[{"note":"Users typically import `jax` (the Python frontend) to interact with the JAX ecosystem, which implicitly 
uses `jaxlib` as its backend. Direct imports from `jaxlib` for general user code are rare.","symbol":"jax","correct":"import jax"},{"note":"JAX's NumPy-like API (`jax.numpy`) is the primary way users interact with array operations that are compiled and executed by `jaxlib`.","symbol":"jax.numpy","correct":"import jax.numpy as jnp"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\n\ndef my_function(x):\n    return jnp.sin(x) * jnp.cos(x)\n\n# JIT-compile the function for performance\ncompiled_function = jax.jit(my_function)\n\n# Create a JAX array\nx = jnp.linspace(0, 10, 1000)\n\n# Run the compiled function\ny = compiled_function(x)\n\nprint(f\"JAX detected devices: {jax.devices()}\")\nprint(f\"Result array shape: {y.shape}\")\nprint(f\"First 5 elements of y: {y[:5]}\")","lang":"python","description":"This quickstart demonstrates a basic JAX program that implicitly relies on `jaxlib` for Just-In-Time (JIT) compilation and execution on the available backend (CPU, GPU, or TPU). It defines a simple numerical function, compiles it with `jax.jit`, and applies it to a JAX array. The output shows the detected devices, the result shape, and the first five elements of the result."},"warnings":[{"fix":"Port `jax.pmap` code to `jax.shard_map`. Consult the JAX migration guide for `pmap`.","message":"The `jax.pmap` function is now in maintenance mode, and its default implementation has changed. Users are strongly encouraged to write new parallel code with `jax.shard_map` instead.","severity":"breaking","affected_versions":">=0.8.0"},{"fix":"Upgrade NumPy to version 2.0 or newer and SciPy to 1.13 or newer (`pip install --upgrade numpy scipy`).","message":"The minimum supported NumPy version is now 2.0; because SciPy 1.13 is the first release that supports NumPy 2.0, the minimum supported SciPy version is now 1.13. 
Older versions are unsupported and can cause installation or import errors.","severity":"breaking","affected_versions":">=0.7.2"},{"fix":"Pass the producer array itself (any object implementing `__dlpack__` and `__dlpack_device__`, such as a NumPy array or PyTorch tensor) instead of a DLPack capsule.","message":"`jax.dlpack.from_dlpack` no longer accepts a raw DLPack capsule directly. It must now be called with an array implementing the `__dlpack__` and `__dlpack_device__` protocols.","severity":"breaking","affected_versions":">=0.7.2"},{"fix":"Adopt a functional programming style for array manipulations. Use `array = array.at[index].set(value)` for updates, which returns a new array with the modification.","message":"JAX arrays are immutable, unlike NumPy arrays. In-place updates common in NumPy (e.g., `arr[0] = 5`) are not supported and raise a `TypeError`; use functional updates such as `.at[idx].set(value)` instead.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Always refer to the official JAX installation guide for the correct, platform-specific commands. If you use a locally installed CUDA toolkit, ensure its CUDA/cuDNN versions are compatible with the `jaxlib` wheel you are installing.","message":"Installing `jaxlib` for NVIDIA GPUs or TPUs requires platform-specific installation commands (e.g., `jax[cuda12]` or `jax[tpu]`), and locally installed CUDA/cuDNN versions must match the wheel. Running `pip install jaxlib` alone typically installs a CPU-only build, and mismatched versions can cause runtime errors or undetected devices.","severity":"gotcha","affected_versions":"All versions"},{"fix":"If your code relies on distinguishing `TypedNdArray` from `np.ndarray`, update those type checks. Use `np.asarray(x)` to obtain a plain NumPy array if necessary.","message":"The semi-private type `jax._src.literals.TypedNdArray` is now a subclass of `np.ndarray`, not just a duck type. 
Code that uses `isinstance(x, np.ndarray)` or similar checks to tell these internal values apart from NumPy arrays will now treat them as NumPy arrays.","severity":"gotcha","affected_versions":">=0.9.2"}],"env_vars":null,"last_verified":"2026-04-05T00:00:00.000Z","next_check":"2026-07-04T00:00:00.000Z"}