JAX CUDA 13 Plugin

0.10.0 verified Sat May 09 auth: no python

JAX plugin providing NVIDIA GPU support for CUDA 13.x. Version 0.10.0 is compatible with JAX v0.10.0. Release cadence follows JAX releases.

pip install jax-cuda13-plugin
error ImportError: libcuda.so.1: cannot open shared object file: No such file or directory
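cause libcuda.so.1 is the NVIDIA driver library (it ships with the driver, not with the CUDA toolkit) and the dynamic loader cannot find it. Either no driver is installed, or its library directory is not on the loader search path, which is common in containers started without GPU access.
fix
Install an NVIDIA driver that supports CUDA 13.x and make sure the directory containing libcuda.so.1 is on the loader search path. If CUDA toolkit libraries are missing as well, install CUDA 13.x from NVIDIA and ensure LD_LIBRARY_PATH includes /usr/local/cuda-13/lib64.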
error jax._src.plugins.PluginNotInstalledError: jax_cuda13_plugin not installed
cause The plugin package is not installed, or its version does not match the installed JAX version.
fix
Run pip install jax-cuda13-plugin==0.10.0, substituting the version that matches your installed JAX.
error RuntimeError: Failed to initialize XLA: CUDA driver version is insufficient for CUDA runtime version
cause The installed NVIDIA driver is too old for CUDA 13.
fix
Update the NVIDIA driver to a release that supports CUDA 13 (on Linux, the R580 branch or newer; check NVIDIA's CUDA release notes for the exact minimum).
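To narrow down a loader failure like the first error above, the following sketch checks whether the dynamic loader can open a shared library by name. It uses only the standard library; the helper name can_load is hypothetical, and the library name is taken verbatim from the error message, so it applies to Linux only.

```python
import ctypes


def can_load(library_name: str) -> bool:
    """Return True if the dynamic loader can open the named shared library."""
    try:
        ctypes.CDLL(library_name)
    except OSError:
        # Raised when the loader cannot find or open the file.
        return False
    return True


if __name__ == "__main__":
    # "libcuda.so.1" is the name reported in the ImportError (Linux).
    print("libcuda.so.1 loadable:", can_load("libcuda.so.1"))
```

If this prints False, the problem lies in the system setup rather than in JAX or the plugin.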
breaking Plugin version must exactly match JAX major.minor version (e.g., 0.10.x). Mismatch causes import errors.
fix Install jax-cuda13-plugin==0.10.0 alongside jax==0.10.0.
gotcha jax-cuda13-plugin only supports CUDA 13.x; it will not work with other CUDA versions. Ensure your system has CUDA 13 runtime.
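fix Check the driver with `nvidia-smi` (its "CUDA Version" field shows the highest CUDA version the driver supports, not the installed toolkit), then install the matching plugin (e.g., jax-cuda12-plugin for CUDA 12).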
gotcha Importing jax_cuda13_plugin is not required; merely installing it enables GPU support. The import statement is a no-op and may cause confusion.
fix Do not import the plugin explicitly; install it and import jax as usual.
pip install "jax[cuda13]==0.10.0" jax-cuda13-plugin==0.10.0
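Since the plugin must match JAX at the major.minor level, a quick check of the installed versions can catch a mismatch before importing jax. The sketch below reads distribution metadata with the standard library; the helper names are hypothetical, and the distribution names "jax" and "jax-cuda13-plugin" are assumed to be the ones installed by the pip command above.

```python
from importlib import metadata
from typing import Optional, Tuple


def major_minor(version: str) -> Tuple[str, ...]:
    """Return the first two dot-separated components, e.g. ('0', '10')."""
    return tuple(version.split(".")[:2])


def versions_match(jax_version: str, plugin_version: str) -> bool:
    """True when both versions share the same major.minor prefix."""
    return major_minor(jax_version) == major_minor(plugin_version)


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v = installed_version("jax")
    plugin_v = installed_version("jax-cuda13-plugin")
    print("jax:", jax_v, "| jax-cuda13-plugin:", plugin_v)
    if jax_v is None or plugin_v is None:
        print("At least one of the two packages is not installed.")
    elif versions_match(jax_v, plugin_v):
        print("major.minor versions match")
    else:
        print("version mismatch: reinstall with matching pins")
```

The comparison is deliberately string-based and ignores the patch component, which mirrors the major.minor rule stated above.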

Verify that a GPU device is detected after installing the plugin.

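import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])  # placed on the default device
print(x)                     # confirms a computation ran on that device
print(jax.devices())         # a CPU-only list means the plugin did not load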
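For scripts or CI jobs, the same check can fail loudly instead of relying on someone reading the printed device list. This is a minimal sketch: the helper names are hypothetical, and it assumes JAX reports NVIDIA devices with the platform string "gpu" ("cuda" is accepted too, in case a build uses that name).

```python
from typing import Iterable, List

# Assumption: NVIDIA devices report platform "gpu"; "cuda" is accepted as well.
GPU_PLATFORMS = {"gpu", "cuda"}


def has_gpu(platforms: Iterable[str]) -> bool:
    """True if any platform string denotes a GPU device."""
    return any(p.lower() in GPU_PLATFORMS for p in platforms)


def detected_platforms() -> List[str]:
    """Return the platform string of every device JAX can see."""
    import jax  # imported lazily so has_gpu() is usable without JAX

    return [d.platform for d in jax.devices()]


if __name__ == "__main__":
    platforms = detected_platforms()
    print("platforms:", platforms)
    if not has_gpu(platforms):
        raise SystemExit("No GPU device found; JAX is running on CPU only.")
```

Exiting with a non-zero status makes a silent CPU fallback visible in automated runs.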