# Tokamax: A GPU and TPU Custom Kernel Library
Tokamax is an OpenXLA library providing high-performance custom accelerator kernels for NVIDIA GPUs and Google TPUs. It offers state-of-the-art implementations built on top of JAX and Pallas, along with tooling for users to build and autotune their own custom kernels. As of version 0.0.12, it is still under heavy development, and users should anticipate API changes.
## Common errors
- `jax.errors.JAXTypeError: Custom call '...' not allowed in an exported function.`
  - **Cause:** Attempting to export a JAX function containing `tokamax` kernels using `jax.export` without disabling the necessary checks.
  - **Fix:** Pass `disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS` to `jax.export.export`. Example: `f_exported = jax.export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)`.
- `tokamax.exceptions.UnsupportedImplementationError: Unsupported implementation for kernel 'layer_norm': mosaic (e.g., FP64 inputs are not supported).`
  - **Cause:** You explicitly requested a kernel implementation (e.g., `"mosaic"`) that does not support the current input data types (e.g., `jnp.float64`) or hardware configuration.
  - **Fix:** Either change the input data type to a supported one (e.g., `jnp.bfloat16`, `jnp.float32`), or set `implementation=None` to let Tokamax automatically select a compatible backend.
- `AttributeError: module 'tokamax' has no attribute 'some_function_name'`
  - **Cause:** You are likely using an outdated API call. Tokamax is in active development, and function names or modules may change between minor versions.
  - **Fix:** Update `tokamax` to the latest version (`pip install -U tokamax`) and consult the official GitHub repository's README or source code for the current API. If necessary, pin a known-working version.
## Warnings
- **Breaking:** Tokamax is still under heavy development. Expect incomplete features and API changes, especially given its pre-1.0 version number.
- **Gotcha:** Autotuning kernels with `tokamax.autotune` is fundamentally non-deterministic because kernel execution-time measurements are noisy. Different configurations chosen during autotuning can lead to numerical non-determinism.
- **Gotcha:** When exporting JAX functions containing Tokamax kernels using `jax.export`, you must disable export checks by passing `disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS`. Without this, JAX will refuse to export the custom calls. Functions serialized this way also lose the device-independence of standard StableHLO.
- **Gotcha:** Specifying a particular `implementation` for a kernel (e.g., `implementation="mosaic"`) can raise exceptions if that implementation is unsupported for the given inputs (e.g., FP64 inputs) or hardware (e.g., older GPUs).
## Install

- `pip install -U tokamax`
- `pip install git+https://github.com/openxla/tokamax.git`
## Imports

- `tokamax`: `import tokamax`
- `jax`: `import jax`
- `jax.numpy`: `import jax.numpy as jnp`
- `layer_norm`: `tokamax.layer_norm`
- `dot_product_attention`: `tokamax.dot_product_attention`
- `autotune`: `tokamax.autotune`
- `standardize_function`: `tokamax.standardize_function`
- `benchmark`: `tokamax.benchmark`
## Quickstart

```python
import jax
import jax.numpy as jnp
import tokamax

def loss_function(x, scale):
    # Apply layer normalization with a Triton implementation.
    x = tokamax.layer_norm(
        x, scale=scale, offset=None, implementation="triton"
    )
    # Apply dot-product attention, allowing Tokamax to select the best
    # implementation.
    x = tokamax.dot_product_attention(x, x, x, implementation=None)
    return jnp.sum(x)

# Example usage with jax.jit and jax.grad.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 2048, 64), dtype=jnp.bfloat16)
scale = jax.random.normal(key, (64,), dtype=jnp.bfloat16)
f_grad = jax.jit(jax.grad(loss_function))
output_grad = f_grad(x, scale)
print("Computed gradient successfully.")

# Example of autotuning (requires compatible hardware):
# autotune_result = tokamax.autotune(loss_function, x, scale)
# with autotune_result:
#     out_autotuned = f_grad(x, scale)
#     print("Autotuned output successfully.")
```