JAX CUDA 12 PJRT Plugin
The `jax-cuda12-pjrt` package provides the XLA PJRT backend plugin that lets JAX run on NVIDIA GPUs, built against CUDA 12. JAX loads it alongside `jaxlib` when GPU acceleration is desired; it is not a replacement for `jaxlib`. The current version is 0.9.2. JAX and its ecosystem packages follow a rapid release cadence, often with monthly or bi-monthly updates.
Warnings
- breaking The default `jax.pmap` implementation has changed, and `jax.pmap` is now in maintenance mode. Users are encouraged to migrate to `jax.shard_map` for new code and distributed computations.
- breaking `jax.dlpack.from_dlpack` no longer accepts a DLPack capsule directly. It now requires an array implementing `__dlpack__` and `__dlpack_device__`.
- deprecated `jax.lax.pvary` has been deprecated. The JAX changelog points to `jax.lax.pcast(..., to='varying')` as the replacement.
- gotcha `jax.numpy.arange` with a `step` argument no longer generates the array on the host by default. This change improves efficiency but can lead to less precise outputs for narrow-width floats (e.g., bfloat16).
- gotcha When using `jax.shard_map` in 'Explicit' mode, JAX will now raise an error if the `PartitionSpec` of an input does not match the `PartitionSpec` specified in `in_specs`. Previously, this might have silently caused an implicit reshard.
- gotcha `jax-cuda12-pjrt` is a plugin that works together with `jaxlib`; it does not replace it. Mismatched versions of `jax`, `jaxlib`, and the CUDA plugin packages can lead to conflicts, unexpected device selection, or errors. Keep all of them on matching versions.
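As an illustration of the `jax.pmap` to `jax.shard_map` migration mentioned above, here is a minimal sketch rather than an official recipe. The mesh axis name `"i"`, the helper name `sum_of_squares_shard`, and the fallback import for older JAX releases are assumptions made for this example.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Assumption: fall back to the experimental location on JAX releases
# that predate the top-level jax.shard_map.
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# One-dimensional mesh over every visible device; the axis name "i" is arbitrary.
mesh = Mesh(np.array(jax.devices()), ("i",))

def sum_of_squares_shard(x_shard):
    # Each device reduces its own shard, then psum adds the partial sums
    # across the "i" axis so every device holds the same total.
    return jax.lax.psum(jnp.sum(x_shard ** 2), "i")

sharded_sum_of_squares = jax.jit(
    shard_map(sum_of_squares_shard, mesh=mesh, in_specs=P("i"), out_specs=P())
)

# The leading dimension must be divisible by the number of devices.
x = jnp.arange(4 * jax.device_count(), dtype=jnp.float32)
total = sharded_sum_of_squares(x)
print(total)
```

Unlike `jax.pmap`, the input keeps its global shape here: `in_specs=P("i")` tells JAX how to split it across the mesh, and `out_specs=P()` declares the result as replicated.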
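For the `jax.numpy.arange` precision note above, one possible workaround (an assumption on our part, not something the JAX documentation prescribes) is to build the range on the host in a wider dtype with NumPy and cast once to the narrow dtype. The variable names below are hypothetical.

```python
import numpy as np
import jax.numpy as jnp

# Build the range on the host in float32, where step accumulation is more precise.
host_range = np.arange(0.0, 1.0, 0.125, dtype=np.float32)

# Cast a single time to the narrow dtype when moving the data into JAX.
narrow_range = jnp.asarray(host_range, dtype=jnp.bfloat16)
print(narrow_range.dtype, narrow_range.shape)
```

This trades a host-side allocation and a transfer for tighter control over rounding, so it is probably only worth it when the exact values of a narrow-float range matter.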
Install
pip install jax jax-cuda12-pjrt
Imports
- jax
import jax
Quickstart
import jax
import jax.numpy as jnp

# Check for available devices
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

# Define a JIT-compiled function
@jax.jit
def sum_of_squares(x):
    return jnp.sum(x**2)

# Create some data
data = jnp.array([1.0, 2.0, 3.0, 4.0])

# Run the function
result = sum_of_squares(data)
print(f"Input data: {data}")
print(f"Result (sum of squares): {result}")

# Report where the result lives when a GPU backend is active.
# jax.devices('gpu') raises RuntimeError when no GPU backend is present,
# and Array.devices() replaces the removed Array.device() method.
if jax.default_backend() == "gpu":
    print(f"Result devices: {result.devices()}")