JAX CUDA 13 PJRT Plugin
0.10.0 verified Sat May 09 auth: no python
JAX PJRT plugin for NVIDIA GPUs using CUDA 13. It provides the XLA compiler backend that lets JAX run on NVIDIA hardware. Plugin version 0.10.0 pairs with JAX v0.10.0, and new plugin releases follow the JAX release cadence.
pip install jax-cuda13-pjrt
Common errors
error RuntimeError: Unable to find backend for plugin jax_cuda13_pjrt
cause Plugin not imported before JAX initializes, or incompatible JAX version.
fix Add 'import jax_cuda13_pjrt' as the first JAX-related import, and ensure the jax and jax-cuda13-pjrt versions match.
error No GPU/TPU found, falling back to CPU.
cause CUDA 13 not installed or the plugin not imported.
fix Check the CUDA installation (nvcc --version) and import the plugin: import jax_cuda13_pjrt
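The CPU fallback described above can also be detected in code rather than by reading the log. The following is a minimal sketch, assuming each entry returned by jax.devices() exposes a `platform` string and that NVIDIA devices report "gpu" (the value "cuda" is accepted as a precaution). The helper name `gpu_platforms_present` is hypothetical, and it takes the device list as an argument so that it runs without JAX or a GPU.

```python
from types import SimpleNamespace

# Assumption: platform names JAX may report for NVIDIA devices.
GPU_PLATFORM_NAMES = {"gpu", "cuda"}


def gpu_platforms_present(devices):
    """Return True if any device in the list reports a GPU platform."""
    return any(
        getattr(device, "platform", "").lower() in GPU_PLATFORM_NAMES
        for device in devices
    )


# Stand-ins for the output of jax.devices(), so the sketch runs anywhere.
cpu_only = [SimpleNamespace(platform="cpu")]
with_gpu = [SimpleNamespace(platform="cpu"), SimpleNamespace(platform="gpu")]

print(gpu_platforms_present(cpu_only))  # False
print(gpu_platforms_present(with_gpu))  # True
```

In a real script, pass jax.devices() to the helper and fail early (or log a warning) when it returns False, instead of silently running on CPU.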
Warnings
gotcha The plugin must be imported before any JAX computation to ensure the GPU backend is selected.
fix Always import jax_cuda13_pjrt at the top of your script.
breaking Plugin version must match the JAX version. Using mismatched versions may cause runtime errors or undefined behavior.
fix Ensure jax==0.10.0 and jax-cuda13-pjrt==0.10.0 are installed together.
gotcha The plugin is only for CUDA 13.x; it will not work with older CUDA toolkits.
fix Verify your CUDA version with nvcc --version. For CUDA 12.x, use jax-cuda12-pjrt instead.
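The version and toolkit checks in these warnings can be scripted as a preflight step. This is a rough sketch using only the Python standard library. It assumes the distribution names "jax" and "jax-cuda13-pjrt" from this page, and that nvcc prints a line of the form "release X.Y". All helper names are hypothetical, and the sample nvcc line is illustrative rather than captured output.

```python
import re
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def versions_match(jax_version, plugin_version):
    """True only when both versions are known and identical."""
    return jax_version is not None and jax_version == plugin_version


def cuda_major_from_nvcc(nvcc_output):
    """Extract the CUDA major version from `nvcc --version` text, or None."""
    match = re.search(r"release\s+(\d+)\.\d+", nvcc_output)
    return int(match.group(1)) if match else None


def plugin_for_cuda(major):
    """Map a CUDA major version to the plugin package named on this page."""
    if major == 13:
        return "jax-cuda13-pjrt"
    if major == 12:
        return "jax-cuda12-pjrt"
    return None  # no recommendation for other toolkits


# Illustrative nvcc line (an assumption, not real captured output).
sample = "Cuda compilation tools, release 13.0, V13.0.48"
print(plugin_for_cuda(cuda_major_from_nvcc(sample)))  # jax-cuda13-pjrt

# The result depends on what is installed in the current environment.
print(versions_match(installed_version("jax"),
                     installed_version("jax-cuda13-pjrt")))
```

To check a real machine, feed the text printed by nvcc --version into cuda_major_from_nvcc, for example via subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout.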
Imports
- jax_cuda13_pjrt
import jax_cuda13_pjrt
Quickstart
import jax_cuda13_pjrt  # noqa: F401  (imported first for its side effect: registers the NVIDIA GPU plugin)
import jax
import jax.numpy as jnp
# Verify device is visible
print(jax.devices())
# Simple computation
x = jnp.array([1, 2, 3])
y = jnp.square(x)
print(y)