jaxlib

0.9.2 · active · verified Sun Apr 05

jaxlib is the support library for JAX. It contains the binary (C/C++) parts of the JAX ecosystem: the Python bindings, the XLA compiler, the PJRT runtime, and various handwritten kernels. JAX itself is a pure-Python package that provides the high-level API, and jaxlib is its compiled backend, which enables high-performance numerical computation on CPUs, GPUs, and TPUs. The current version is 0.9.2. jaxlib is released frequently, typically in step with `jax` releases, so the two packages are normally installed and upgraded together.
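
Because `jax` and `jaxlib` ship as separate packages, it can help to check which versions are installed and which backend JAX picked up. The following is a minimal sketch that assumes both packages were installed with pip, so their distribution metadata is readable through `importlib.metadata`. It uses only `jax.default_backend()` and `jax.devices()` from JAX itself.

```python
# Sketch: inspect the jax (Python front end) / jaxlib (compiled backend) split.
# Assumes both packages were installed via pip so their metadata is available.
from importlib.metadata import version

import jax

jax_version = version("jax")        # pure-Python API package
jaxlib_version = version("jaxlib")  # binary package: XLA, PJRT runtime, kernels
backend = jax.default_backend()     # platform name JAX dispatches to, e.g. "cpu"

print(f"jax {jax_version}, jaxlib {jaxlib_version}")
print(f"default backend: {backend}")

# Devices are provided by the PJRT runtime that jaxlib ships.
for device in jax.devices():
    print(device.platform, device.id)
```

On a CPU-only machine this typically reports a single `cpu` device. The exact version strings depend on the local installation.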

Quickstart

This quickstart is a basic JAX program that uses `jaxlib` implicitly. `jax.jit` hands the traced function to the XLA compiler, and the compiled code runs on whatever accelerator is available (CPU, GPU, or TPU). The program defines a simple numerical function, compiles it with `jax.jit`, applies it to a JAX array, and prints the detected devices, the shape of the result, and its first five elements.

```python
import jax
import jax.numpy as jnp

def my_function(x):
    return jnp.sin(x) * jnp.cos(x)

# JIT-compile the function; XLA (in jaxlib) compiles it on the first call
compiled_function = jax.jit(my_function)

# Create a JAX array of 1000 evenly spaced points in [0, 10]
x = jnp.linspace(0, 10, 1000)

# Run the compiled function
y = compiled_function(x)

print(f"JAX detected devices: {jax.devices()}")
print(f"Result array shape: {y.shape}")
print(f"First 5 elements of y: {y[:5]}")
```
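
In the quickstart, `jax.jit` traces, lowers, and compiles the function implicitly on the first call. JAX's ahead-of-time API makes those stages explicit, which is one way to see the XLA compilation step that jaxlib performs. This is a sketch that assumes a recent JAX release where `jax.jit(...).lower(...)`, `Lowered.as_text()`, and `Lowered.compile()` are available. The exact IR text printed by `as_text()` varies between versions.

```python
# Sketch: run the stages jax.jit normally performs implicitly.
import jax
import jax.numpy as jnp
import numpy as np

def my_function(x):
    return jnp.sin(x) * jnp.cos(x)

x = jnp.linspace(0, 10, 1000)

lowered = jax.jit(my_function).lower(x)  # trace + lower for this shape/dtype
hlo_text = lowered.as_text()             # IR handed to XLA (StableHLO in recent releases)
compiled = lowered.compile()             # XLA compilation for the default backend

y = compiled(x)

# sin(x) * cos(x) == 0.5 * sin(2x), so the result can be checked with NumPy.
expected = 0.5 * np.sin(2 * np.asarray(x))
print(np.allclose(y, expected, atol=1e-5))
```

A compiled executable is specialized to the input shape and dtype it was lowered with. This is why `jax.jit` recompiles when it sees a new input shape.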
