{"id":4364,"library":"jax-cuda12-plugin","title":"JAX Plugin for NVIDIA GPUs (CUDA 12)","description":"JAX is a Python library from Google for high-performance numerical computing. It provides a NumPy-like interface with automatic differentiation and composable function transformations, and runs on CPUs, GPUs, and TPUs. The `jax-cuda12-plugin` package provides NVIDIA GPU support for JAX in CUDA 12.x environments. JAX and its core library `jaxlib` (which this plugin extends) are actively maintained with frequent releases, typically on a monthly or bi-monthly schedule for minor versions.","status":"active","version":"0.9.2","language":"en","source_language":"en","source_url":"https://github.com/jax-ml/jax","tags":["machine learning","gpu","cuda","array programming","automatic differentiation","numerical computing"],"install":[{"cmd":"pip install jax-cuda12-plugin","lang":"bash","label":"Basic installation"},{"cmd":"pip install -U --pre jax jaxlib \"jax-cuda12-plugin[with-cuda]\" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html","lang":"bash","label":"Recommended installation with JAX core (latest pre-release)"}],"dependencies":[{"reason":"Core JAX library for array operations and transformations.","package":"jax"},{"reason":"JAX's compiled backend (XLA) and platform-specific code. This plugin extends jaxlib with CUDA 12.x support.","package":"jaxlib"}],"imports":[{"symbol":"jax","correct":"import jax"},{"note":"Although JAX's numpy API mirrors NumPy's, aliasing standard numpy as `jnp` is incorrect: the code may still run, but it will not use JAX's accelerator support or function transformations.","wrong":"import numpy as jnp","symbol":"jax.numpy","correct":"import jax.numpy as jnp"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\n\n# Verify GPU device availability\nprint(\"Available devices:\", jax.devices())\n\n# Define a simple function\ndef f(x):\n  return jnp.sum(x**2 + 2*x + 1)\n\n# Just-in-Time compilation for performance\nf_jit = jax.jit(f)\n\n# Automatic differentiation for gradients\ngrad_f = jax.grad(f)\ngrad_f_jit = jax.jit(grad_f)\n\nx = jnp.array([1.0, 2.0, 3.0])\n\nprint(\"Original function output:\", f(x))\nprint(\"JIT compiled function output:\", f_jit(x))\nprint(\"Gradient of function:\", grad_f(x))\nprint(\"JIT compiled gradient:\", grad_f_jit(x))\n\n# Example of immutability (common gotcha):\n# Attempting x[0] = 5.0 would raise a TypeError.\n# Correct way to 'update' an array (creates a new array):\nx_new = x.at[0].set(5.0)\nprint(\"Original array (unchanged):\", x)\nprint(\"Updated array (new object):\", x_new)","lang":"python","description":"This quickstart demonstrates core JAX functionality: using the NumPy-like API (`jax.numpy`), applying Just-In-Time (JIT) compilation with `jax.jit` for performance, and computing gradients automatically with `jax.grad`. It also highlights the immutability of JAX arrays, a key difference from NumPy."},"warnings":[{"fix":"Migrate from `jax.pmap` to `jax.shard_map` or `jax.jit`. Consult the JAX migration guide for details on moving from `pmap` to `shard_map`.","message":"JAX's parallel map (`jax.pmap`) is in maintenance mode and its default implementation has changed. New code is strongly encouraged to use `jax.shard_map` or `jax.jit` for automatic parallelism. The `auto=` parameter of `jax.experimental.shard_map.shard_map` was removed in v0.8.0.","severity":"breaking","affected_versions":">=0.8.0"},{"fix":"Ensure you are using the plugin-based installation, typically by installing `jax-cuda12-plugin` (or `jax[cuda12]`) as per the official JAX installation instructions.","message":"Support for monolithic CUDA `jaxlib` wheels (e.g., `jaxlib==0.4.29+cuda12`) has been dropped. All CUDA support is now provided via plugin-based installations.","severity":"breaking","affected_versions":">=0.4.30"},{"fix":"Use the `.at[]` syntax for element-wise updates, which returns a new array with the changes (e.g., `x = x.at[0].set(10)`).","message":"JAX arrays are immutable. Attempting in-place modification (e.g., `x[0] = 10`) will result in a `TypeError`.","severity":"gotcha","affected_versions":"All versions"},{"fix":"To recover the previous host-based precision for narrow-width floats, generate the array on the host with NumPy and convert it: `jnp.array(np.arange(...))`.","message":"`jax.numpy.arange` with a specified `step` no longer generates the array on the host. This can lead to less precise outputs for narrow-width floats (e.g., bfloat16) compared to previous versions.","severity":"gotcha","affected_versions":">=0.9.2"},{"fix":"Upgrade NumPy to version 2.0 or newer (and SciPy to 1.13 or newer if used).","message":"The minimum supported NumPy version for JAX is now 2.0.","severity":"breaking","affected_versions":">=0.7.2"},{"fix":"Use `jax.lax.pcast(..., to='varying')` as the replacement.","message":"`jax.lax.pvary` has been deprecated.","severity":"deprecated","affected_versions":">=0.8.2"}],"env_vars":null,"last_verified":"2026-04-12T00:00:00.000Z","next_check":"2026-07-11T00:00:00.000Z"}