Triton

3.6.0 · active · verified Sat Mar 28

Triton is a language and compiler for writing highly efficient custom deep-learning operations. It provides a Python-based programming environment for writing custom GPU kernels that can reach performance on par with hand-tuned CUDA, with higher productivity and flexibility than existing DSLs. Triton bridges the gap between high-level deep-learning frameworks and low-level GPU programming. The current version is 3.6.0, and the project releases frequently (multiple major/minor releases per year).

Warnings

Install
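Triton publishes prebuilt wheels on PyPI, so a typical install (Linux, CPython) looks like the following; the quickstart below also needs PyTorch with CUDA support, which is a separate install:

```shell
# Install Triton from PyPI (prebuilt wheels are available for Linux)
pip install triton

# The quickstart additionally uses PyTorch for tensor allocation
pip install torch
```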

Imports
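The conventional imports are `triton` for the JIT decorator and launch machinery, and `triton.language` (aliased `tl`) for the in-kernel primitives; the quickstart also imports `torch` for tensor allocation:

```python
import triton               # @triton.jit, kernel launch, triton.cdiv, autotuning
import triton.language as tl  # in-kernel primitives: tl.load, tl.store, tl.arange, ...
import torch                # tensors passed to kernels as device pointers
```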

Quickstart

This quickstart demonstrates a basic vector addition kernel using Triton. It shows how to define a JIT-compiled kernel with `@triton.jit`, load and store data using `triton.language` primitives like `tl.load` and `tl.store`, and launch the kernel from Python with a specified grid size. This example processes elements in blocks, illustrating Triton's approach to GPU parallelism.

import triton
import triton.language as tl
import torch

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Map program_id to a block of elements
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Create a mask to handle out-of-bounds accesses
    mask = offsets < n_elements

    # Load data from memory
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Perform addition
    output = x + y

    # Write back to memory
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and y.is_cuda, "inputs must be CUDA tensors"
    assert x.shape == y.shape, "inputs must have the same shape"
    output = torch.empty_like(x)
    n_elements = x.numel()

    # BLOCK_SIZE is a compile-time constant (`tl.constexpr`), so it cannot
    # depend on `n_elements`. It must be a power of two, because
    # `tl.arange(0, BLOCK_SIZE)` requires one; 1024 is a common default.
    BLOCK_SIZE = 1024

    # Number of programs (blocks) to launch
    grid = lambda META: (triton.cdiv(n_elements, META['BLOCK_SIZE']),)

    # Launch the kernel
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output

if __name__ == "__main__":
    # Example usage with PyTorch tensors
    size = 4096
    x = torch.randn(size, device='cuda')
    y = torch.randn(size, device='cuda')

    output_triton = add(x, y)
    output_torch = x + y

    print(f"Triton output matches PyTorch: {torch.allclose(output_triton, output_torch)}")
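The grid lambda above uses `triton.cdiv` (ceiling division) to decide how many program instances to launch. A quick sanity check of that arithmetic, runnable without a GPU (assuming only that `triton` is importable on the host):

```python
import triton

# cdiv(n, b) == ceil(n / b): the number of BLOCK_SIZE-sized programs
# needed to cover n elements.
n_elements, BLOCK_SIZE = 4096, 1024
print(triton.cdiv(n_elements, BLOCK_SIZE))  # 4 programs cover 4096 elements exactly

# A ragged size needs one extra program; the `mask` in the kernel
# keeps that last program from reading or writing out of bounds.
print(triton.cdiv(4097, 1024))  # 5
```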
