jaxlib
jaxlib is the essential support library for JAX, containing the binary (C/C++) parts of the JAX ecosystem, including Python bindings, the XLA compiler, the PJRT runtime, and various handwritten kernels. While JAX itself is a pure Python package providing the high-level API, jaxlib acts as its compiled backend, enabling high-performance numerical computation on CPUs, GPUs, and TPUs. The current version is 0.9.2. Releases are frequent and usually land together with, or just before, the matching JAX release.
Warnings
- breaking The `jax.pmap` function is now in maintenance mode, and its default implementation has changed. Users are strongly encouraged to migrate new code to `jax.shard_map` for data parallelism.
- breaking The minimum supported NumPy version is now 2.0, and consequently, the minimum supported SciPy version is 1.13. Using older versions will lead to errors.
- breaking `jax.dlpack.from_dlpack` no longer accepts a raw DLPack capsule directly. It must now be called with an array implementing the `__dlpack__` and `__dlpack_device__` protocols.
- gotcha JAX arrays are immutable, unlike NumPy arrays. In-place modification operations common in NumPy (e.g., `arr[0] = 5`) are not supported and raise a `TypeError`. Use explicit functional updates such as `arr.at[idx].set(value)`, which return a new array and leave the original unchanged.
- gotcha Installing `jaxlib` for NVIDIA GPUs or TPUs requires a specific installation command (e.g., `pip install "jax[cuda12]"`) and often requires matching CUDA/cuDNN versions. Running `pip install jaxlib` alone will typically install a CPU-only build, and mismatched versions can lead to runtime errors or to devices not being detected.
- deprecated The semi-private type `jax._src.literals.TypedNdArray` is now a subclass of `np.ndarray`, not just a duck type. This change may affect code relying on `isinstance(x, np.ndarray)` or similar type checks for JAX internal types if they were previously treated as distinct from `np.ndarray`.
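The immutability gotcha above is easiest to see in a short example. The following is a minimal sketch, assuming a working JAX install; the variable names are illustrative only. It shows that NumPy-style item assignment fails and that the `.at[...]` update methods return new arrays instead.

```python
import jax.numpy as jnp

x = jnp.arange(5)  # array with values 0, 1, 2, 3, 4

# NumPy-style in-place assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 5
except TypeError as err:
    print(f"In-place assignment failed with {type(err).__name__}")

# Functional updates return a new array; x itself is never modified.
y = x.at[0].set(5)     # copy of x with element 0 replaced by 5
z = x.at[1:3].add(10)  # copy of x with 10 added to elements 1 and 2

print(x)
print(y)
print(z)
```

Inside a `jax.jit`-compiled function the compiler can often perform such updates without an actual copy, so the functional style does not have to cost extra memory.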
Install
- CPU
pip install --upgrade pip
pip install --upgrade jax jaxlib
- NVIDIA GPU (CUDA 12)
pip install --upgrade pip
pip install --upgrade "jax[cuda12]"
- Other CUDA versions or a locally installed CUDA
# Refer to the official JAX docs:
# https://jax.readthedocs.io/en/latest/installation.html
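After installing, it can help to confirm which versions were picked up and whether an accelerator is visible, since a CPU-only jaxlib or a version mismatch otherwise goes unnoticed until runtime. The check below is a sketch rather than an official procedure; it reads the installed versions through the standard-library `importlib.metadata` and uses `jax.default_backend()` and `jax.devices()`.

```python
from importlib.metadata import version

import jax

# Installed package versions, as recorded by pip.
jax_version = version("jax")
jaxlib_version = version("jaxlib")
print(f"jax {jax_version}, jaxlib {jaxlib_version}")

# Backend JAX will use by default, and the devices it can see.
backend = jax.default_backend()
devices = jax.devices()
print(f"Default backend: {backend}")
print(f"Devices: {devices}")

if backend == "cpu":
    print("Only the CPU backend is active; no GPU or TPU was detected.")
```

If a GPU or TPU was expected but the backend reports `cpu`, the usual suspects are a CPU-only install or mismatched driver and CUDA versions.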
Imports
- jax
import jax
- jax.numpy
import jax.numpy as jnp
Quickstart
import jax
import jax.numpy as jnp
def my_function(x):
    return jnp.sin(x) * jnp.cos(x)
# JIT-compile the function for performance
compiled_function = jax.jit(my_function)
# Create a JAX array
x = jnp.linspace(0, 10, 1000)
# Run the compiled function
y = compiled_function(x)
print(f"JAX detected devices: {jax.devices()}")
print(f"Result array shape: {y.shape}")
print(f"First 5 elements of y: {y[:5]}")