JAX CUDA 12 PJRT Plugin

0.9.2 · active · verified Sun Apr 12

The jax-cuda12-pjrt package provides the XLA PJRT plugin that lets JAX run on NVIDIA GPUs, built against CUDA 12. It does not replace `jaxlib`: JAX loads it as a plugin alongside `jaxlib` when GPU acceleration is requested, and it is normally pulled in by installing `jax[cuda12]` rather than directly. The current version is 0.9.2. JAX and its companion packages are released frequently, often monthly or every other month.
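
Because the plugin ships as a separate distribution from `jax` and `jaxlib`, a quick way to see which pieces are present is to query installed package metadata. The sketch below uses only the standard library's `importlib.metadata`; the list of distribution names (including `jax-cuda12-plugin`) is an assumption about a typical CUDA 12 install, not something this page specifies, and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names for a typical CUDA 12 JAX install; adjust as needed.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```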

Quickstart

This quickstart demonstrates basic JAX usage. It prints the default backend and the available devices, then defines and runs a JIT-compiled function. With a working CUDA 12 setup the backend is reported as `gpu` and the computation runs on the GPU; otherwise JAX falls back to the CPU.

import jax
import jax.numpy as jnp

# Check for available devices
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

# Define a JIT-compiled function
@jax.jit
def sum_of_squares(x):
    return jnp.sum(x**2)

# Create some data
data = jnp.array([1.0, 2.0, 3.0, 4.0])

# Run the function
result = sum_of_squares(data)
print(f"Input data: {data}")
print(f"Result (sum of squares): {result}")

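# Show where the result lives when a GPU is available.
# jax.devices("gpu") raises RuntimeError if no GPU backend exists,
# and Array.devices() replaces the removed Array.device() method.
try:
    gpus = jax.devices("gpu")
except RuntimeError:
    gpus = []
if gpus:
    print(f"Result device(s): {result.devices()}")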
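
The quickstart lets JAX pick the device implicitly. A minimal sketch, assuming you want to pin inputs explicitly instead: the helper `pick_device` (a hypothetical name, not part of JAX) returns the first GPU when one is available and the first CPU otherwise, and `jax.device_put` commits the input to that device so the jitted function runs there. Because JAX dispatches work asynchronously, `block_until_ready()` waits for the computation to finish.

```python
import jax
import jax.numpy as jnp


def pick_device():
    """Hypothetical helper: first GPU if a GPU backend is available, else first CPU."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:
        # Raised when no GPU backend can be initialized on this machine.
        return jax.devices("cpu")[0]


@jax.jit
def sum_of_squares(x):
    return jnp.sum(x**2)


device = pick_device()
# Commit the input to the chosen device; the jitted call then runs on it.
data = jax.device_put(jnp.array([1.0, 2.0, 3.0, 4.0]), device)
result = sum_of_squares(data)
# Wait for the asynchronous computation to complete before using the result.
result.block_until_ready()
print(f"Ran on {device.platform}: {float(result)}")
```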