JAX CUDA 13 Plugin
Version 0.10.0 (verified Sat May 09); auth: no; language: python
A JAX plugin that provides NVIDIA GPU support for CUDA 13.x. Version 0.10.0 is compatible with JAX v0.10.0, and new plugin versions are released alongside JAX releases.
pip install jax-cuda13-plugin
Common errors
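error ImportError: libcuda.so.1: cannot open shared object file: No such file or directory
cause libcuda.so.1 could not be found. This is the CUDA driver library, which is installed with the NVIDIA driver; the plugin requires a CUDA 13.x installation with a compatible driver.
fix
Install CUDA 13.x from NVIDIA, including a compatible NVIDIA driver, and ensure LD_LIBRARY_PATH includes the CUDA library directory (e.g. /usr/local/cuda-13/lib64).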
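To check whether the dynamic loader can find the library named in this ImportError, a minimal sketch (assuming Linux and only the standard-library ctypes module) is to try opening it directly. The helper name libcuda_loadable is hypothetical, not part of the plugin.

```python
import ctypes
import ctypes.util


def libcuda_loadable() -> bool:
    """Return True if the dynamic loader can open libcuda.so.1.

    ctypes.CDLL uses dlopen, so it follows the same search rules
    (including LD_LIBRARY_PATH as set when the process started).
    """
    try:
        ctypes.CDLL("libcuda.so.1")
        return True
    except OSError:
        # Raised when the shared object cannot be found or opened.
        return False


if __name__ == "__main__":
    print("libcuda.so.1 loadable:", libcuda_loadable())
    # find_library returns a name such as "libcuda.so.1", or None if not found.
    print("find_library('cuda'):", ctypes.util.find_library("cuda"))
```

If this prints False, JAX will hit the same error; fix the driver installation before reinstalling the plugin.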
error jax._src.plugins.PluginNotInstalledError: jax_cuda13_plugin not installed
cause The plugin package is not installed, or its version does not match the installed JAX version.
fix
Run
pip install jax-cuda13-plugin==0.10.0
substituting the version that matches your installed JAX version.
error RuntimeError: Failed to initialize XLA: CUDA driver version is insufficient for CUDA runtime version
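cause The installed NVIDIA driver is older than the minimum version required by the CUDA 13 runtime.
fix
Update the NVIDIA driver to a version that supports CUDA 13 (for CUDA 13.0 on Linux, driver 580 or newer; the 545 series only supports CUDA 12.x).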
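The driver check can be scripted. This sketch reads the driver version with nvidia-smi's --query-gpu option and compares it to a minimum. The constant MIN_DRIVER_FOR_CUDA13 = (580, 65, 6) is an assumption taken from NVIDIA's CUDA 13.0 release notes for Linux; confirm it against the notes for your exact CUDA 13.x release. All helper names are hypothetical.

```python
import subprocess

# Assumption: minimum Linux driver for CUDA 13.0 (from NVIDIA's CUDA 13.0
# release notes). Verify for your exact CUDA 13.x release.
MIN_DRIVER_FOR_CUDA13 = (580, 65, 6)


def parse_driver_version(text: str) -> tuple[int, ...]:
    """Turn a driver version string such as '580.65.06' into (580, 65, 6)."""
    return tuple(int(part) for part in text.strip().split("."))


def driver_supports_cuda13(version: str) -> bool:
    """Compare component-wise against the assumed CUDA 13 minimum."""
    return parse_driver_version(version) >= MIN_DRIVER_FOR_CUDA13


def installed_driver_version() -> str:
    """Ask nvidia-smi for the driver version reported for the first GPU."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    return out.splitlines()[0].strip()


if __name__ == "__main__":
    try:
        version = installed_driver_version()
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("nvidia-smi not available; is the NVIDIA driver installed?")
    else:
        print(version, "supports CUDA 13:", driver_supports_cuda13(version))
```

Tuple comparison handles versions like 580.82.07 correctly, which plain string comparison would not guarantee.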
Warnings
breaking The plugin version must exactly match the JAX major.minor version (e.g. 0.10.x with JAX 0.10.x). A mismatch causes import errors.
fix Install jax-cuda13-plugin==0.10.0 alongside jax==0.10.0.
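A sketch of the major.minor rule from the warning above, using the standard-library importlib.metadata to read installed versions. It assumes the distribution names are exactly "jax" and "jax-cuda13-plugin"; the helper names are hypothetical. Pinning identical versions, as the fix line does, remains the safer choice.

```python
from importlib.metadata import PackageNotFoundError, version


def major_minor(v: str) -> tuple[int, int]:
    """Return the (major, minor) part of a version string such as '0.10.0'."""
    major, minor = v.split(".")[:2]
    return int(major), int(minor)


def versions_compatible(jax_version: str, plugin_version: str) -> bool:
    """The plugin's major.minor must equal JAX's major.minor."""
    return major_minor(jax_version) == major_minor(plugin_version)


def check_installed() -> str:
    """Report installed jax / jax-cuda13-plugin versions and whether they match."""
    try:
        jax_v = version("jax")
        plugin_v = version("jax-cuda13-plugin")
    except PackageNotFoundError as exc:
        return f"not installed: {exc}"
    status = "OK" if versions_compatible(jax_v, plugin_v) else "MISMATCH"
    return f"jax {jax_v}, jax-cuda13-plugin {plugin_v}: {status}"


if __name__ == "__main__":
    print(check_installed())
```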
gotcha jax-cuda13-plugin supports only CUDA 13.x and will not work with other CUDA versions. Make sure your system has the CUDA 13 runtime.
fix Check the CUDA version your driver supports with `nvidia-smi`, then install the matching plugin (e.g. jax-cuda12-plugin for CUDA 12).
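The "CUDA Version: X.Y" field in nvidia-smi's default header is the newest CUDA version the installed driver supports, not the version of any installed toolkit. This sketch parses that field and maps the major version to one of the two plugin names mentioned on this page; the mapping and helper names are illustrative assumptions.

```python
import re
import subprocess
from typing import Optional

# Plugin packages named on this page, keyed by CUDA major version.
PLUGIN_FOR_CUDA_MAJOR = {12: "jax-cuda12-plugin", 13: "jax-cuda13-plugin"}

# nvidia-smi's header contains e.g. "CUDA Version: 13.0".
_CUDA_RE = re.compile(r"CUDA Version:\s*(\d+)\.(\d+)")


def reported_cuda_version(smi_output: str) -> Optional[tuple[int, int]]:
    """Extract (major, minor) from nvidia-smi output, or None if absent."""
    match = _CUDA_RE.search(smi_output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def plugin_for(smi_output: str) -> Optional[str]:
    """Pick the plugin for the reported CUDA major version, if one is listed."""
    ver = reported_cuda_version(smi_output)
    return PLUGIN_FOR_CUDA_MAJOR.get(ver[0]) if ver else None


if __name__ == "__main__":
    try:
        text = subprocess.run(["nvidia-smi"], capture_output=True,
                              text=True, check=True).stdout
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("nvidia-smi not available")
    else:
        print(reported_cuda_version(text), plugin_for(text))
```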
gotcha Importing jax_cuda13_plugin is not required: installing it is enough to enable GPU support, because JAX discovers installed plugins on its own. An explicit import is a no-op and may cause confusion.
fix Do not import the plugin explicitly; install it and import jax as usual.
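To confirm the plugin is installed without importing it, and that JAX picks it up after a plain `import jax`, a minimal sketch might look like this. importlib.util.find_spec locates a top-level module without executing it, and JAX reports CUDA devices with platform "gpu". The helper names are hypothetical.

```python
import importlib.util


def plugin_installed() -> bool:
    """True if jax_cuda13_plugin is installed; find_spec does not import it."""
    return importlib.util.find_spec("jax_cuda13_plugin") is not None


def jax_sees_gpu() -> bool:
    """Import only jax and report whether any visible device is a GPU."""
    if importlib.util.find_spec("jax") is None:
        return False  # jax itself is not installed
    import jax

    return any(device.platform == "gpu" for device in jax.devices())


if __name__ == "__main__":
    print("plugin installed:", plugin_installed())
    print("JAX sees a GPU:", jax_sees_gpu())
```

If the plugin is installed but jax_sees_gpu() returns False, the driver and version checks under Common errors are the next places to look.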
Install
pip install "jax[cuda13]==0.10.0" jax-cuda13-plugin==0.10.0
Imports
- jax_cuda13_plugin
import jax_cuda13_plugin
Quickstart
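import jax
import jax.numpy as jnp
# With the plugin working, this should list GPU devices.
print(jax.devices())
# Arrays go to the default device (the GPU when one is available).
x = jnp.array([1., 2., 3.])
print(x.sum())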