Triton
Triton is a language and compiler for writing highly efficient custom deep-learning operations. It provides a Python-based programming environment for writing custom GPU kernels that can reach performance on par with hand-tuned CUDA, with higher productivity and flexibility than existing DSLs. Triton bridges the gap between high-level deep learning frameworks and low-level GPU programming. The current version is 3.6.0, with multiple major/minor releases per year.
Warnings
- breaking Triton 3.4.0 dropped support for Python 3.8. The minimum required Python version is now 3.10, and it supports up to 3.14 (i.e., <3.15). Ensure your Python environment meets these requirements.
- breaking In Triton 3.0.0, the behavior of `tl.constexpr` changed. You can no longer directly call non-Triton functions (e.g., `math.log2`) within a JIT function and assign their results to `tl.constexpr` variables. These values must be pre-computed outside the kernel or implemented with `triton.language` equivalents.
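A minimal sketch of the host-side workaround (kernel and argument names are illustrative, not from the Triton API): compute the constant in regular Python, then pass it into the kernel as a `tl.constexpr` argument.

```python
import math

# In Triton >= 3.0, calling a non-Triton function such as math.log2 inside a
# @triton.jit body and binding it to a tl.constexpr no longer works:
#
#     LOG2_BLOCK_SIZE: tl.constexpr = math.log2(BLOCK_SIZE)  # error inside a kernel
#
# Instead, compute the value on the host and pass it in as a constexpr argument:

BLOCK_SIZE = 1024
LOG2_BLOCK_SIZE = int(math.log2(BLOCK_SIZE))  # computed in regular Python

# The (hypothetical) kernel then declares the value as a compile-time constant:
#     def my_kernel(..., LOG2_BLOCK_SIZE: tl.constexpr): ...
# and the launch site passes LOG2_BLOCK_SIZE=LOG2_BLOCK_SIZE.
```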
- gotcha Triton primarily supports Linux with NVIDIA GPUs. Support for the Turing architecture (sm75, e.g., GTX 16xx/RTX 20xx) was dropped starting from Triton 3.3, so recent releases effectively require Ampere (Compute Capability 8.0) or newer; older releases supported Compute Capability 7.0+ (Volta or newer). AMD GPU support is in development. Official Windows and macOS binaries are not provided; WSL2 is the recommended workaround for Windows. An up-to-date NVIDIA driver is critical for PTX JIT compilation.
- gotcha Triton 3.5.0 introduced a bug that broke `sm103` (NVIDIA GB200/GB300) support. This was quickly patched in the 3.5.1 bug fix release.
- gotcha The official Triton library currently restricts `fp8` (float8) data type support to NVIDIA GPUs with compute capability >= 8.9 (e.g., RTX 40xx and newer). It is not officially supported on Ampere (RTX 30xx) or older architectures.
- gotcha Triton stores cache files in `~/.triton` by default. This can lead to conflicts or unexpected behavior when using different versions or forks of Triton, or when building self-contained applications. There are currently no official environment variables to override all cache-related directories.
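A small sketch for locating and clearing that cache (the directory's internal layout is an implementation detail and varies across versions; the helper name is our own):

```python
import os
import shutil

# Default on-disk cache location used by Triton.
cache_dir = os.path.join(os.path.expanduser("~"), ".triton")

def clear_triton_cache(path: str = cache_dir) -> None:
    """Delete the kernel cache; Triton rebuilds it lazily on the next launch."""
    if os.path.isdir(path):
        shutil.rmtree(path)
```

Clearing the cache is a useful reset when switching between Triton versions or forks that would otherwise trip over each other's cached artifacts.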
Install
-
pip install triton
Imports
- triton
import triton
- triton.language as tl
import triton.language as tl
- triton.jit
@triton.jit
Quickstart
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Map program_id to a block of elements
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to handle out-of-bounds accesses
    mask = offsets < n_elements
    # Load data from memory
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Perform addition
    output = x + y
    # Write back to memory
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    assert x.is_cuda and y.is_cuda
    output = torch.empty_like(x)
    n_elements = x.numel()
    # BLOCK_SIZE is a compile-time constant, so it can't depend on `n_elements`.
    # 1024 is a reasonable default; tune it for your hardware if needed.
    BLOCK_SIZE = 1024
    # Number of programs (blocks) to launch: one per BLOCK_SIZE-sized chunk
    grid = lambda META: (triton.cdiv(n_elements, META['BLOCK_SIZE']),)
    # Launch the kernel
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output

if __name__ == "__main__":
    # Example usage with PyTorch tensors
    size = 4096
    x = torch.randn(size, device='cuda')
    y = torch.randn(size, device='cuda')
    output_triton = add(x, y)
    output_torch = x + y
    print(f"Triton output matches PyTorch: {torch.allclose(output_triton, output_torch)}")
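The launch grid above sizes itself with `triton.cdiv`, i.e. ceiling division, so a trailing partial block still gets its own program (whose out-of-range lanes are guarded by the mask inside the kernel). A plain-Python equivalent of that helper:

```python
def cdiv(n: int, d: int) -> int:
    # Ceiling division: number of d-sized programs needed to cover n elements.
    return (n + d - 1) // d

# 4096 elements with BLOCK_SIZE=1024 -> exactly 4 programs.
assert cdiv(4096, 1024) == 4
# 4097 elements -> 5 programs; the 5th handles one element, the mask covers the rest.
assert cdiv(4097, 1024) == 5
```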