{"id":4363,"library":"jax-cuda12-pjrt","title":"JAX CUDA 12 PJRT Plugin","description":"The jax-cuda12-pjrt package provides the JAX XLA PJRT backend for NVIDIA GPUs, built against CUDA 12. It serves as the `jaxlib` implementation when GPU acceleration is desired. The current version is 0.9.2; JAX and its ecosystem components typically follow a rapid release cadence, often with monthly or bi-monthly updates.","status":"active","version":"0.9.2","language":"en","source_language":"en","source_url":"https://github.com/jax-ml/jax","tags":["deep learning","machine learning","gpu","cuda","xla","numpy","auto-differentiation","scientific computing"],"install":[{"cmd":"pip install \"jax[cuda12]\"","lang":"bash","label":"Install JAX with CUDA 12 GPU support (installs jax-cuda12-pjrt as a dependency)"}],"dependencies":[{"reason":"This package provides the GPU backend for the core JAX library. JAX itself is required for usage.","package":"jax"}],"imports":[{"note":"The `jax-cuda12-pjrt` package provides the backend implementation (jaxlib) and is not directly imported into user code. All interactions are through the core `jax` library.","wrong":"import jax_cuda12_pjrt","symbol":"jax","correct":"import jax"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\n\n# Check the active backend and available devices\nprint(f\"JAX backend: {jax.default_backend()}\")\nprint(f\"Available devices: {jax.devices()}\")\n\n# Define a JIT-compiled function\n@jax.jit\ndef sum_of_squares(x):\n    return jnp.sum(x**2)\n\n# Create some data\ndata = jnp.array([1.0, 2.0, 3.0, 4.0])\n\n# Run the function\nresult = sum_of_squares(data)\nprint(f\"Input data: {data}\")\nprint(f\"Result (sum of squares): {result}\")\n\n# Report device placement when a GPU backend is active\n# (jax.devices('gpu') raises RuntimeError on CPU-only installs)\nif jax.default_backend() == 'gpu':\n    print(f\"Result device: {result.device}\")","lang":"python","description":"This quickstart demonstrates basic JAX usage. 
It checks the JAX backend and available devices, then defines and executes a JIT-compiled function, confirming that the computation leverages GPU acceleration if available."},"warnings":[{"fix":"Refer to the JAX migration guide for `pmap` (docs.jax.dev/en/latest/migrate_pmap.html) and use `jax.shard_map` instead of `jax.pmap`.","message":"The default `jax.pmap` implementation has changed, and `jax.pmap` is now in maintenance mode. Users are encouraged to migrate to `jax.shard_map` for new code and distributed computations.","severity":"breaking","affected_versions":">=0.8.0"},{"fix":"Update code to pass an array-like object that implements the DLPack protocol, rather than a raw capsule, to `jax.dlpack.from_dlpack`.","message":"`jax.dlpack.from_dlpack` no longer accepts a DLPack capsule directly. It now requires an array implementing `__dlpack__` and `__dlpack_device__`.","severity":"breaking","affected_versions":">=0.7.2"},{"fix":"Use `jax.lax.pcast(..., to='varying')` as the replacement for `jax.lax.pvary`.","message":"`jax.lax.pvary` has been deprecated.","severity":"deprecated","affected_versions":">=0.8.2"},{"fix":"To recover the previous host-based generation and ensure higher precision for narrow-width floats, explicitly cast the NumPy output: `jnp.array(np.arange(...))`.","message":"`jax.numpy.arange` with a `step` argument no longer generates the array on the host by default. This change improves efficiency but can lead to less precise outputs for narrow-width floats (e.g., bfloat16).","severity":"gotcha","affected_versions":">=0.9.2"},{"fix":"Ensure that the `PartitionSpec` of your inputs precisely matches the `in_specs` argument when using `jax.shard_map` in 'Explicit' mode. Omit `in_specs` if you intend for `shard_map` to infer the partitioning.","message":"When using `jax.shard_map` in 'Explicit' mode, JAX will now raise an error if the `PartitionSpec` of an input does not match the `PartitionSpec` specified in `in_specs`. 
Previously, this might have silently caused an implicit reshard.","severity":"gotcha","affected_versions":">=0.9.1"},{"fix":"Ensure only `jax` and your desired `jaxlib` variant (e.g., `jax-cuda12-pjrt` for CUDA 12 GPU or `jaxlib` for CPU) are installed. Uninstall any conflicting `jaxlib` packages before installing your target GPU backend.","message":"Installing `jaxlib` (the CPU-only version) alongside a GPU-specific `jaxlib` variant like `jax-cuda12-pjrt` can lead to conflicts, unexpected device selection, or errors. Only one `jaxlib` implementation should be installed.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-12T00:00:00.000Z","next_check":"2026-07-11T00:00:00.000Z"}