JAX Plugin for NVIDIA GPUs (CUDA 12)

0.9.2 · active · verified Sun Apr 12

JAX is a Python library by Google for high-performance numerical computing, providing a NumPy-like interface with automatic differentiation and function transformations, capable of running on CPUs, GPUs, and TPUs. The `jax-cuda12-plugin` specifically provides NVIDIA GPU support for JAX, compatible with CUDA 12.x environments. JAX and its core library `jaxlib` (which this plugin extends) are actively maintained with frequent releases, typically on a monthly or bi-monthly schedule for minor versions.

Warnings

Install
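Following JAX's standard installation pattern, the CUDA 12 plugin is normally pulled in through the `cuda12` extra on the `jax` package rather than installed directly; exact version pins vary by release, so treat this as a sketch:

```shell
# Installs jax plus the matching jaxlib and CUDA 12 plugin wheels
# (bundled CUDA libraries; no system CUDA toolkit required).
pip install "jax[cuda12]"
```

Installing `jax-cuda12-plugin` on its own is possible, but its version must match the installed `jaxlib`, which the extra handles automatically.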

Imports
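The conventional import style, plus a quick way to confirm which backend JAX picked up (a minimal sketch; `default_backend` reports `"gpu"` when the CUDA plugin is active):

```python
# Standard JAX import convention; jax.numpy mirrors the NumPy API.
import jax
import jax.numpy as jnp

print(jax.__version__)
# "gpu" if the CUDA 12 plugin loaded successfully, otherwise "cpu".
print(jax.default_backend())
```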

Quickstart

This quickstart demonstrates core JAX functionalities: utilizing the NumPy-like API (`jax.numpy`), applying Just-In-Time (JIT) compilation with `jax.jit` for performance, and computing gradients automatically using `jax.grad`. It also highlights the immutability of JAX arrays, a key difference from NumPy.

import jax
import jax.numpy as jnp

# Verify GPU device availability
print("Available devices:", jax.devices())

# Define a simple function
def f(x):
  return jnp.sum(x**2 + 2*x + 1)

# Just-in-Time compilation for performance
f_jit = jax.jit(f)

# Automatic differentiation for gradients
grad_f = jax.grad(f)
grad_f_jit = jax.jit(grad_f)

x = jnp.array([1.0, 2.0, 3.0])

print("Original function output:", f(x))
print("JIT compiled function output:", f_jit(x))
print("Gradient of function:", grad_f(x))
print("JIT compiled gradient:", grad_f_jit(x))

# Example of immutability (common gotcha):
# Attempting x[0] = 5.0 would raise a TypeError.
# Correct way to 'update' an array (creates a new array):
x_new = x.at[0].set(5.0)
print("Original array (unchanged):", x)
print("Updated array (new object):", x_new)
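As a sanity check on the gradient above: since f(x) = Σ(x² + 2x + 1), its elementwise gradient is 2x + 2, which `jax.grad` should reproduce. A small sketch comparing the two:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2 + 2*x + 1)

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad differentiates f with respect to its first argument.
auto = jax.grad(f)(x)

# Analytic gradient of sum(x**2 + 2x + 1) is 2x + 2.
analytic = 2 * x + 2

print(auto)      # [4. 6. 8.]
print(analytic)  # [4. 6. 8.]
```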
