jaxlib
jaxlib is the essential support library for JAX, containing the binary (C/C++) parts of the JAX ecosystem, including Python bindings, the XLA compiler, the PJRT runtime, and various handwritten kernels. While JAX itself is a pure Python package providing the high-level API, jaxlib acts as its compiled backend, enabling high-performance numerical computation on CPUs, GPUs, and TPUs. The current version is 0.9.2, and it follows a frequent release cadence, often aligning with or preceding JAX releases.
Common errors
- ModuleNotFoundError: No module named 'jaxlib'
  Cause: The `jaxlib` package, which contains JAX's compiled backend, is not installed or cannot be imported from the current Python environment, even though `jax` itself may be installed.
  Fix: Install JAX together with the appropriate backend (CPU, CUDA, ROCm, or TPU) so that pip pulls in a matching `jaxlib`. Upgrade `pip` first, then use the install command from the official JAX documentation for your accelerator. For CPU: `pip install --upgrade pip jax`. For NVIDIA GPUs with CUDA 12: `pip install --upgrade pip "jax[cuda12]"`.
- ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
  Cause: `pip` cannot find a prebuilt `jaxlib` wheel that is compatible with your Python version, operating system, or CPU architecture (and, for GPU installs, your CUDA/cuDNN setup). This is most common on less common platforms, with Python versions that are too old or too new for the published wheels, or with nonstandard CUDA configurations.
  Fix: Check that your Python version is supported by the JAX release you are installing; recent releases no longer support Python 3.9, so consult the official installation guide for the current minimum and for compatible CUDA and cuDNN versions. Use the exact `pip install` command from that guide, which may include a JAX release URL when wheels for your setup are not on PyPI. Make sure `pip` is up to date (`pip install --upgrade pip`), and consider a fresh virtual environment with a supported Python version.
- RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.
  Cause: Your CPU (or operating system) does not support AVX (Advanced Vector Extensions) instructions, which the prebuilt x86-64 `jaxlib` wheels have required since version 0.1.62.
  Fix: The recommended route for older hardware without AVX is to build `jaxlib` from source with a configuration that does not target AVX; this requires a C++ toolchain. Installing an older `jaxlib` that predates the AVX requirement may also work, but that is not officially supported or tested and is likely to cause compatibility problems with current `jax` releases.
- WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  Cause: JAX detected an NVIDIA GPU, but the installed `jaxlib` is either CPU-only or incompatible with your CUDA toolkit, cuDNN libraries, or GPU driver. A related symptom is `jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed`.
  Fix: Install a GPU-enabled JAX build that matches your driver, CUDA, and cuDNN versions, following the official installation instructions. For NVIDIA GPUs with CUDA 12, `pip install --upgrade "jax[cuda12]"` installs JAX together with the CUDA and cuDNN libraries from PyPI; substitute the extra that matches your CUDA version. Keep your NVIDIA driver up to date and compatible with that CUDA version, and avoid setting `LD_LIBRARY_PATH`, since it can interfere with how JAX locates its CUDA libraries.
- AttributeError: module 'jax' has no attribute 'version'
  Cause: This usually signals a version mismatch between the installed `jax` and `jaxlib` packages: `jax` expects attributes or functionality that the installed `jaxlib` does not provide, or vice versa.
  Fix: Make sure `jax` and `jaxlib` are compatible. The most reliable fix is to uninstall both packages and reinstall `jax` with the recommended command for your system and accelerator, which automatically installs a matching `jaxlib`. For example: `pip uninstall jax jaxlib`, then `pip install -U "jax[cpu]"` or `pip install -U "jax[cuda12]"`.
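When working through the errors above, it helps to confirm what is actually installed before reinstalling anything. The following is a minimal diagnostic sketch; the helper names `installed_version` and `describe_jax_install` are hypothetical and not part of JAX. It reads package versions with the standard-library `importlib.metadata`, imports `jax` only when both `jax` and `jaxlib` are present, and then reports the backend JAX selected via `jax.default_backend()` and `jax.devices()`.

```python
import importlib.metadata


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


def describe_jax_install() -> dict:
    """Collect the version and backend details that help diagnose install errors."""
    info = {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        "backend": None,
        "devices": [],
    }
    if info["jax"] is None or info["jaxlib"] is None:
        # Importing jax would fail here (e.g. "No module named 'jaxlib'").
        return info

    import jax  # imported lazily so the version check above always runs

    # "cpu" on a machine with an NVIDIA GPU means no CUDA-enabled backend was loaded.
    info["backend"] = jax.default_backend()
    info["devices"] = [str(d) for d in jax.devices()]
    return info


if __name__ == "__main__":
    info = describe_jax_install()
    for key, value in info.items():
        print(f"{key}: {value}")
    if info["jax"] and info["jaxlib"] and info["jax"] != info["jaxlib"]:
        print("note: jax and jaxlib versions differ; if you see AttributeErrors, "
              "reinstall jax so that pip pulls a matching jaxlib")
```

A `backend` of `cpu` on a machine that has an NVIDIA GPU corresponds to the CPU-fallback warning, and `None` for `jaxlib` corresponds to the `ModuleNotFoundError` case.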
Warnings
- breaking The `jax.pmap` function is now in maintenance mode, and its default implementation has changed. Users are strongly encouraged to migrate new code to `jax.shard_map` for data parallelism.
- breaking The minimum supported NumPy version is now 2.0, and consequently, the minimum supported SciPy version is 1.13. Using older versions will lead to errors.
- breaking `jax.dlpack.from_dlpack` no longer accepts a raw DLPack capsule directly. It must now be called with an array implementing the `__dlpack__` and `__dlpack_device__` protocols.
- gotcha JAX arrays are immutable, unlike NumPy arrays. NumPy-style in-place assignment (e.g., `arr[0] = 5`) raises a `TypeError`; use a functional update such as `arr = arr.at[0].set(5)` instead, which returns a new array.
- gotcha Installing JAX for NVIDIA GPUs or TPUs requires accelerator-specific install commands (e.g., `"jax[cuda12]"` or `"jax[tpu]"`), and GPU installs often need matching CUDA/cuDNN versions. Running `pip install jaxlib` alone typically installs a CPU-only build, and mismatched versions can cause runtime errors or devices that are not detected.
- deprecated The semi-private type `jax._src.literals.TypedNdArray` is now a subclass of `np.ndarray` rather than only a duck type. Code that relies on `isinstance(x, np.ndarray)` (or similar type checks) to tell these JAX-internal values apart from NumPy arrays may now behave differently.
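The `jax.pmap` warning above points to `jax.shard_map` as the replacement for new code. Below is a minimal sketch of a sharded reduction, assuming a JAX release that exposes `jax.shard_map` at the top level; on older releases the code falls back to `jax.experimental.shard_map`. The mesh axis name `"data"` and the helper `global_sum` are illustrative choices, not JAX names. It also runs on a single CPU device, where the mesh simply contains one device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Prefer the top-level jax.shard_map; older releases only ship the experimental one.
try:
    shard_map = jax.shard_map
except AttributeError:
    from jax.experimental.shard_map import shard_map

# A 1-D mesh over every visible device, with its single axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


def global_sum(block):
    # Each device sums its own block, then psum adds the partial sums
    # across the "data" axis so every device holds the same total.
    return jax.lax.psum(block.sum(), axis_name="data")


# in_specs=P("data") splits the input's first axis across devices;
# out_specs=P() declares the scalar result as replicated.
sharded_sum = jax.jit(
    shard_map(global_sum, mesh=mesh, in_specs=P("data"), out_specs=P())
)

n_devices = len(jax.devices())
x = jnp.arange(8 * n_devices, dtype=jnp.float32)  # length divisible by device count
total = sharded_sum(x)
print(total)
```

Where `pmap` maps over a leading axis with one slice per device, `shard_map` takes the global array and splits it into per-device blocks according to `in_specs`; here every block has length 8.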
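The `jax.dlpack.from_dlpack` change above means JAX expects the producing array object itself rather than a raw DLPack capsule. A minimal sketch, assuming a recent NumPy (whose arrays implement `__dlpack__` and `__dlpack_device__`) and a CPU-resident source array; arrays from other libraries that implement the same protocol follow the same pattern but are not shown here.

```python
import numpy as np
import jax
import jax.dlpack

# A NumPy array implements __dlpack__ and __dlpack_device__,
# so it can be passed to from_dlpack directly (no capsule needed).
host = np.arange(6, dtype=np.float32).reshape(2, 3)
x = jax.dlpack.from_dlpack(host)

print(type(x).__name__, x.shape, x.dtype)
```

Code that previously passed a capsule to `jax.dlpack.from_dlpack` should pass the producing array instead.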
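The immutability gotcha is easiest to see side by side. This short sketch shows that item assignment on a JAX array raises `TypeError`, while the `.at[...]` update methods (`set`, `add`) return updated copies and leave the original array unchanged.

```python
import jax.numpy as jnp

arr = jnp.zeros(5)

rejected = False
try:
    arr[0] = 5.0  # NumPy-style in-place assignment
except TypeError as err:
    rejected = True
    print(f"in-place update rejected: {type(err).__name__}")

# Functional updates return new arrays; `arr` itself is never modified.
updated = arr.at[0].set(5.0)
incremented = updated.at[1:3].add(1.0)

print(arr)          # [0. 0. 0. 0. 0.]
print(updated)      # [5. 0. 0. 0. 0.]
print(incremented)  # [5. 1. 1. 0. 0.]
```

Inside `jax.jit`, when the input to an `.at[...]` update is not used again, XLA can usually perform the update in place, so the functional style does not necessarily cost an extra copy.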
Install
- CPU:
  pip install --upgrade pip
  pip install --upgrade jax jaxlib
- NVIDIA GPU (CUDA 12):
  pip install --upgrade pip
  pip install --upgrade "jax[cuda12]"
# For specific CUDA versions or locally installed CUDA, refer to the official JAX docs:
# https://jax.readthedocs.io/en/latest/installation.html
Imports
- jax
import jax
- jax.numpy
import jax.numpy as jnp
Quickstart
import jax
import jax.numpy as jnp
def my_function(x):
    return jnp.sin(x) * jnp.cos(x)
# JIT-compile the function for performance
compiled_function = jax.jit(my_function)
# Create a JAX array
x = jnp.linspace(0, 10, 1000)
# Run the compiled function
y = compiled_function(x)
print(f"JAX detected devices: {jax.devices()}")
print(f"Result array shape: {y.shape}")
print(f"First 5 elements of y: {y[:5]}")
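The quickstart compiles `my_function` with `jax.jit` for performance. Because JAX compiles on the first call and dispatches work asynchronously, a fair timing has to wait for results with `block_until_ready()` and keep the first (compiling) call separate from later ones. A rough sketch reusing the quickstart's function; the timings it prints depend entirely on your hardware and are not benchmarks.

```python
import time

import jax
import jax.numpy as jnp


def my_function(x):
    return jnp.sin(x) * jnp.cos(x)


compiled_function = jax.jit(my_function)
x = jnp.linspace(0, 10, 1000)

# The first call traces my_function and compiles it with XLA;
# later calls with the same shape and dtype reuse the cached executable.
start = time.perf_counter()
compiled_function(x).block_until_ready()  # wait for the asynchronous work to finish
first_call = time.perf_counter() - start

start = time.perf_counter()
y = compiled_function(x).block_until_ready()
second_call = time.perf_counter() - start

# JIT changes how the function runs, not what it computes.
assert jnp.allclose(y, my_function(x), atol=1e-6)
print(f"first call (includes compilation): {first_call * 1e3:.2f} ms")
print(f"second call (cached executable): {second_call * 1e3:.2f} ms")
```

Calling the jitted function with a different input shape or dtype triggers a fresh compilation, so the first-call cost reappears in that case.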