Tokamax: A GPU and TPU Custom Kernel Library

0.0.12 · active · verified Thu Apr 16

Tokamax is an OpenXLA library providing high-performance custom accelerator kernels for NVIDIA GPUs and Google TPUs. It offers state-of-the-art implementations built on top of JAX and Pallas, along with tooling for users to build and autotune their own custom kernels. As of version 0.0.12, it is still under heavy development, and users should anticipate API changes.

Quickstart

This quickstart applies Tokamax custom kernels (e.g., `tokamax.layer_norm`, `tokamax.dot_product_attention`) inside a JAX computation. It shows how to request a specific kernel implementation, or pass `implementation=None` to let Tokamax select the best available one, and how the kernels compose with JAX's `jit` and `grad` transformations.

import jax
import jax.numpy as jnp
import tokamax

def loss_function(x, scale):
    # Apply layer normalization with a Triton implementation
    x = tokamax.layer_norm(
        x, scale=scale, offset=None, implementation="triton"
    )
    # Apply dot product attention, allowing Tokamax to select the best implementation
    x = tokamax.dot_product_attention(x, x, x, implementation=None)
    return jnp.sum(x)

# Example usage with JAX JIT and Grad
key = jax.random.PRNGKey(0)
x_key, scale_key = jax.random.split(key)  # never reuse a PRNG key for two samples
x = jax.random.normal(x_key, (32, 2048, 64), dtype=jnp.bfloat16)
scale = jax.random.normal(scale_key, (64,), dtype=jnp.bfloat16)

f_grad = jax.jit(jax.grad(loss_function))
output_grad = f_grad(x, scale)
print("Computed gradient successfully.")
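To sanity-check a kernel's output, it can help to compare against a plain reference implementation. The sketch below is a minimal pure-NumPy layer norm, assuming `tokamax.layer_norm` follows the standard definition (normalize over the last axis, then apply `scale` and optionally `offset`); the `eps` value here is illustrative, not Tokamax's actual default.

```python
import numpy as np

def layer_norm_reference(x, scale, offset=None, eps=1e-6):
    # Normalize over the last axis to zero mean and unit variance,
    # then apply the learned scale (and offset, if given).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + eps)
    y = y * scale
    if offset is not None:
        y = y + offset
    return y

rng = np.random.default_rng(0)
x_np = rng.normal(size=(2, 8, 64)).astype(np.float32)
y_np = layer_norm_reference(x_np, np.ones(64, dtype=np.float32))
# Each normalized row has (approximately) zero mean and unit variance.
```

Run the kernel and the reference on the same inputs (cast to `float32` for the comparison, since `bfloat16` loses precision) and check they agree to within a small tolerance.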

# Example of autotuning (requires compatible hardware)
# autotune_result = tokamax.autotune(loss_function, x, scale)
# with autotune_result:
#    out_autotuned = f_grad(x, scale)
#    print("Autotuned output successfully.")
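The attention call above can be checked the same way. This is a minimal NumPy sketch of scaled dot-product attention, `softmax(q kᵀ / √d) v`, assuming `tokamax.dot_product_attention` follows the standard semantics over `(batch, sequence, dim)` arrays; it ignores options such as masking.

```python
import numpy as np

def attention_reference(q, k, v):
    # softmax(q @ kᵀ / sqrt(d)) @ v, with the softmax over the key axis.
    d = q.shape[-1]
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    logits -= logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 16, 8)).astype(np.float32)
k = rng.normal(size=(2, 16, 8)).astype(np.float32)
v = rng.normal(size=(2, 16, 8)).astype(np.float32)
out = attention_reference(q, k, v)
```

Because each output row is a convex combination of the rows of `v`, every output value stays within the per-feature range of `v`, which makes a quick correctness bound easy to assert.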
