{"id":8718,"library":"tokamax","title":"Tokamax: A GPU and TPU Custom Kernel Library","description":"Tokamax is an OpenXLA library providing high-performance custom accelerator kernels for NVIDIA GPUs and Google TPUs. It offers state-of-the-art implementations built on top of JAX and Pallas, along with tooling for users to build and autotune their own custom kernels. As of version 0.0.12, it is still under heavy development, and users should anticipate API changes.","status":"active","version":"0.0.12","language":"en","source_language":"en","source_url":"https://github.com/openxla/tokamax","tags":["JAX","GPU","TPU","Custom Kernels","Deep Learning","Performance","Pallas","OpenXLA"],"install":[{"cmd":"pip install -U tokamax","lang":"bash","label":"Install latest PyPI release"},{"cmd":"pip install git+https://github.com/openxla/tokamax.git","lang":"bash","label":"Install bleeding-edge from GitHub (no stability guarantees)"}],"dependencies":[{"reason":"Type annotations for JAX arrays","package":"jaxtyping"},{"reason":"Array manipulation","package":"einshape"},{"reason":"Progress bars","package":"tqdm"},{"reason":"Immutable dictionary types","package":"immutabledict"},{"reason":"Backports of new typing features","package":"typing-extensions"},{"reason":"Core dependency for JAX-based computation","package":"jax"},{"reason":"Runtime type checking","package":"typeguard"},{"reason":"JAX's compiled XLA operations","package":"jaxlib"},{"reason":"Quantization library for JAX","package":"qwix"},{"reason":"Abseil Python Common Libraries","package":"absl-py"},{"reason":"Data validation and settings management","package":"pydantic"},{"reason":"TensorBoard logging utilities","package":"tensorboardx"}],"imports":[{"symbol":"tokamax","correct":"import tokamax"},{"symbol":"jax","correct":"import jax"},{"symbol":"jax.numpy","correct":"import jax.numpy as 
jnp"},{"symbol":"layer_norm","correct":"tokamax.layer_norm"},{"symbol":"dot_product_attention","correct":"tokamax.dot_product_attention"},{"symbol":"autotune","correct":"tokamax.autotune"},{"symbol":"standardize_function","correct":"tokamax.standardize_function"},{"symbol":"benchmark","correct":"tokamax.benchmark"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport tokamax\n\ndef loss_function(x, scale):\n    # Apply layer normalization with a Triton implementation\n    x = tokamax.layer_norm(\n        x, scale=scale, offset=None, implementation=\"triton\"\n    )\n    # Apply dot product attention, allowing Tokamax to select the best implementation\n    x = tokamax.dot_product_attention(x, x, x, implementation=None)\n    return jnp.sum(x)\n\n# Example usage with jax.jit and jax.grad.\n# Split the key so that x and scale are drawn from independent random streams.\nkey_x, key_scale = jax.random.split(jax.random.PRNGKey(0))\nx = jax.random.normal(key_x, (32, 2048, 64), dtype=jnp.bfloat16)\nscale = jax.random.normal(key_scale, (64,), dtype=jnp.bfloat16)\n\nf_grad = jax.jit(jax.grad(loss_function))\noutput_grad = f_grad(x, scale)\nprint(\"Computed gradient successfully.\")\n\n# Example of autotuning (requires compatible hardware)\n# autotune_result = tokamax.autotune(loss_function, x, scale)\n# with autotune_result:\n#    out_autotuned = f_grad(x, scale)\n#    print(\"Autotuned output successfully.\")","lang":"python","description":"This quickstart applies `tokamax` custom kernels (e.g., `layer_norm`, `dot_product_attention`) within a JAX computation graph. It shows how to request a specific kernel implementation or let `tokamax` select the best one, and how the kernels compose with JAX's `jit` and `grad` transformations."},"warnings":[{"fix":"Refer to the latest GitHub README and documentation for current API usage. Pin specific versions to avoid unexpected breakage in production environments.","message":"Tokamax is still heavily under development. 
Incomplete features and API changes are to be expected, especially given its pre-1.0 version number.","severity":"breaking","affected_versions":"<1.0.0"},{"fix":"To ensure consistent numerics across sessions, serialize and explicitly reuse autotuning results using `autotune_result.dumps()` and `tokamax.AutotuningResult.loads()`.","message":"Autotuning kernels with `tokamax.autotune` is fundamentally non-deterministic due to noisy kernel execution time measurements. Different configurations chosen during autotuning can lead to numerical non-determinism.","severity":"gotcha","affected_versions":"All"},{"fix":"Add `disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS` to your `jax.export` call. Be aware that the exported function will be specific to the device it was serialized on.","message":"When exporting JAX functions containing Tokamax kernels using `jax.export`, you must disable export checks by passing `disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS`. Without this, JAX will prevent custom calls from being exported. Functions serialized this way also lose the device-independence of standard StableHLO.","severity":"gotcha","affected_versions":"All"},{"fix":"If encountering `UnsupportedImplementationError` or similar, try `implementation=None` to let Tokamax select the best available implementation, which can fall back to XLA. Alternatively, consult the documentation for supported hardware and data types for the chosen implementation.","message":"Specifying a particular `implementation` for a kernel (e.g., `implementation=\"mosaic\"`) can lead to exceptions if that implementation is unsupported for the given inputs (e.g., FP64 inputs) or hardware (e.g., older GPUs).","severity":"gotcha","affected_versions":"All"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Pass `disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS` to `jax.export`. 
Example: `exported = jax.export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)(x, scale)`. Note that `jax.export.export` returns a function; calling it with example arguments (or matching `jax.ShapeDtypeStruct` placeholders) produces the exported artifact.","cause":"Attempting to export a JAX function containing `tokamax` kernels using `jax.export.export` without disabling the necessary checks.","error":"jax.errors.JAXTypeError: Custom call '...' not allowed in an exported function."},{"fix":"Either change your input data type to a supported one (e.g., `jnp.bfloat16`, `jnp.float32`), or set `implementation=None` to allow Tokamax to automatically select a compatible backend.","cause":"You explicitly requested a kernel implementation (e.g., 'mosaic') that does not support the current input data types (e.g., `jnp.float64`) or hardware configuration.","error":"tokamax.exceptions.UnsupportedImplementationError: Unsupported implementation for kernel 'layer_norm': mosaic (e.g., FP64 inputs are not supported)."},{"fix":"Update your `tokamax` library to the latest version (`pip install -U tokamax`) and consult the official GitHub repository's README or source code for the most current API. If necessary, pin a specific working version.","cause":"The call likely targets an API that has been renamed or removed. Tokamax is in active development, and function names or modules might have changed between releases.","error":"AttributeError: module 'tokamax' has no attribute 'some_function_name'"}]}