JAX Plugin for NVIDIA GPUs (CUDA 12)
JAX is a Python library from Google for high-performance numerical computing. It provides a NumPy-like interface with automatic differentiation and function transformations, and runs on CPUs, GPUs, and TPUs. The `jax-cuda12-plugin` package provides NVIDIA GPU support for JAX in CUDA 12.x environments. JAX and its core library `jaxlib` (which this plugin extends) are actively maintained, with minor releases typically every one to two months.
Warnings
- breaking JAX's parallel map (`jax.pmap`) is in maintenance mode and its default implementation has changed. New code should use `jax.shard_map` (explicit per-device programming) or `jax.jit` (compiler-driven automatic parallelism) instead. The `auto=` parameter of `jax.experimental.shard_map.shard_map` was removed in v0.8.0.
- breaking Support for monolithic CUDA `jaxlibs` (e.g., `jaxlib==0.4.29+cuda12`) has been dropped. All CUDA support is now provided via plugin-based installations.
- gotcha JAX arrays are immutable. Attempting in-place modification (e.g., `x[0] = 10`) will result in a `TypeError`.
- gotcha `jax.numpy.arange` with a specified `step` no longer generates the array on the host. This can lead to less precise outputs for narrow-width floats (e.g., bfloat16) compared to previous versions.
- breaking The minimum supported NumPy version for JAX is now 2.0.
- deprecated `jax.lax.pvary` has been deprecated.
Install
-
pip install jax-cuda12-plugin
-
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
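After installing, a quick sanity check that the backend was picked up (device names and the backend string vary by setup; this is a generic check, not specific to this plugin):

```python
import jax

# Lists the devices JAX found; with the CUDA plugin installed on a GPU machine
# these will be CUDA devices, otherwise JAX falls back to CPU
print(jax.devices())
print(jax.default_backend())  # e.g. 'gpu' or 'cpu'
```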
Imports
- jax
import jax
- jax.numpy
import jax.numpy as jnp
Quickstart
import jax
import jax.numpy as jnp
# Verify GPU device availability
print("Available devices:", jax.devices())
# Define a simple function
def f(x):
    return jnp.sum(x**2 + 2*x + 1)
# Just-in-Time compilation for performance
f_jit = jax.jit(f)
# Automatic differentiation for gradients
grad_f = jax.grad(f)
grad_f_jit = jax.jit(grad_f)
x = jnp.array([1.0, 2.0, 3.0])
print("Original function output:", f(x))
print("JIT compiled function output:", f_jit(x))
print("Gradient of function:", grad_f(x))
print("JIT compiled gradient:", grad_f_jit(x))
# Example of immutability (common gotcha):
# Attempting x[0] = 5.0 would raise a TypeError.
# Correct way to 'update' an array (creates a new array):
x_new = x.at[0].set(5.0)
print("Original array (unchanged):", x)
print("Updated array (new object):", x_new)
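Beyond `.set`, the `.at` property supports other functional updates, each returning a new array. A short sketch of common variants:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Each .at[...] call returns a NEW array; x itself is never modified
added = x.at[1].add(10.0)        # [1.0, 12.0, 3.0]
scaled = x.at[:2].multiply(2.0)  # [2.0, 4.0, 3.0]

# Out-of-bounds updates do not raise; with mode="drop" the update is skipped
dropped = x.at[10].set(99.0, mode="drop")  # x unchanged: [1.0, 2.0, 3.0]

print(added, scaled, dropped)
```

Inside `jax.jit`-compiled code, these functional updates are typically fused by the compiler, so they do not necessarily copy the full array.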