TileLang - High-Performance Kernel Development DSL

0.1.8 · active · verified Thu Apr 16

TileLang (tile-lang) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU/accelerator kernels, such as GEMM, Dequant GEMM, and FlashAttention. It provides a Pythonic syntax with an underlying compiler infrastructure built on Apache TVM, allowing developers to focus on productivity while achieving state-of-the-art performance. The library is actively developed, with frequent updates and nightly builds, currently at version 0.1.8.

Quickstart

This quickstart demonstrates how to define and execute a matrix multiplication (GEMM) kernel using TileLang, integrating with PyTorch for tensor management and validation. It showcases decorators like `@tilelang.jit` and `@T.prim_func`, memory allocation with `T.alloc_shared` and `T.alloc_fragment`, data movement with `T.copy`, matrix multiplication with `T.gemm`, and loop pipelining with `T.Pipelined`.

import tilelang
import tilelang.language as T
import torch

@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float32", out_dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Stage tiles of A and B in shared memory; accumulate partial products in registers
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):  # software-pipeline the K loop to overlap copies with compute
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])
    return main

M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 2. Test the kernel in Python with PyTorch data
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float32)  # must match the kernel's out_dtype

# Run the kernel
matmul_kernel(a, b, c)

# Reference multiplication using PyTorch
ref_c = (a @ b).to(c.dtype)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# (Optional) Profile kernel latency with the built-in profiler
# profiler = matmul_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
# latency = profiler.do_bench()
# print(f"Latency: {latency} ms")
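To see where the launch configuration above comes from, here is a plain-Python sketch of the `T.ceildiv` tile math. The `ceildiv` helper below is a stand-in written for illustration, not part of the TileLang API:

```python
def ceildiv(a, b):
    # Ceiling division, mirroring T.ceildiv in the kernel above
    return -(-a // b)

M = N = K = 1024
block_M, block_N, block_K = 128, 128, 64

grid_x = ceildiv(N, block_N)   # thread blocks along N
grid_y = ceildiv(M, block_M)   # thread blocks along M
k_steps = ceildiv(K, block_K)  # pipelined K-loop iterations per block

print(grid_x, grid_y, k_steps)  # 8 8 16
```

With these sizes every tile divides evenly; ceiling division matters when a dimension is not a multiple of its block size (e.g. M = 1000 still needs 8 row-blocks of 128).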
