{"id":27880,"library":"jax-cuda13-pjrt","title":"JAX CUDA 13 PJRT Plugin","description":"PJRT plugin that provides the XLA backend for JAX on NVIDIA GPUs using CUDA 13. Version 0.10.0 is released alongside JAX v0.10.0; new versions follow the JAX release cadence.","status":"active","version":"0.10.0","language":"python","source_language":"en","source_url":"https://github.com/jax-ml/jax","tags":["jax","cuda13","pjrt","nvidia","gpu","xla"],"install":[{"cmd":"pip install -U \"jax[cuda13]\"","lang":"bash","label":"Recommended: install JAX with CUDA 13 support (pulls in this plugin)"},{"cmd":"pip install jax-cuda13-pjrt","lang":"bash","label":"Install from PyPI"}],"dependencies":[{"reason":"JAX core is required; the plugin is a backend.","package":"jax","optional":false},{"reason":"Provides the core runtime used by jax; its version must match jax and this plugin.","package":"jaxlib","optional":false}],"imports":[{"note":"No explicit import is needed: the package installs the module jax_plugins.xla_cuda13 and registers it under the jax_plugins entry-point group, so JAX loads it automatically when it initializes its backends. There is no module named jax_cuda13_pjrt.","wrong":"import jax_cuda13_pjrt","symbol":"jax","correct":"import jax"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\n\n# No plugin import is needed: JAX discovers the installed CUDA 13 plugin automatically.\n\n# Verify the GPU device is visible\nprint(jax.devices())\n\n# Simple computation\nx = jnp.array([1, 2, 3])\ny = jnp.square(x)\nprint(y)","lang":"python","description":"Install the plugin, then use JAX as usual; the CUDA 13 backend is registered automatically."},"warnings":[{"fix":"Remove the import and use JAX directly; call jax.devices() to confirm the GPU backend is active. If you installed the plugin inside an already-running interpreter or notebook, restart it so JAX can discover the plugin.","message":"Importing jax_cuda13_pjrt fails with ModuleNotFoundError because no module by that name exists. JAX discovers the installed plugin automatically, so no import is needed.","severity":"gotcha","affected_versions":"all"},{"fix":"Install matching versions together, e.g. jax==0.10.0, jaxlib==0.10.0 and jax-cuda13-pjrt==0.10.0, or let pip resolve them with pip install -U \"jax[cuda13]\".","message":"The plugin version must match the jax and jaxlib versions. Mismatched versions may cause runtime errors or undefined behavior.","severity":"breaking","affected_versions":"all"},{"fix":"Check the highest CUDA version your NVIDIA driver supports with nvidia-smi (and a local toolkit, if any, with nvcc --version). If it is older than CUDA 13, use jax-cuda12-pjrt instead.","message":"The plugin is built only for CUDA 13.x; it will not work with drivers or toolkits for older CUDA versions.","severity":"gotcha","affected_versions":"all"}],"env_vars":null,"last_verified":"2026-05-09T00:00:00.000Z","next_check":"2026-08-07T00:00:00.000Z","problems":[{"fix":"Make sure jax, jaxlib and jax-cuda13-pjrt are installed in the active environment with matching versions (for example, reinstall with pip install -U \"jax[cuda13]\"). No explicit plugin import is needed.","cause":"The plugin is not installed in the active environment, or its version does not match the installed JAX version.","error":"RuntimeError: Unable to find backend for plugin jax_cuda13_pjrt"},{"fix":"Confirm the plugin is installed in the active environment (pip show jax-cuda13-pjrt) and that the NVIDIA driver supports CUDA 13 (nvidia-smi).","cause":"The plugin is not installed in the active environment, or no CUDA 13-capable NVIDIA driver is available.","error":"No GPU/TPU found, falling back to CPU."}],"ecosystem":"pypi","meta_description":null,"install_score":null,"install_tag":null,"quickstart_score":null,"quickstart_tag":null}