Pre-compiled cubins for FlashInfer

0.6.7.post3 · active · verified Sun Apr 12

FlashInfer-cubin provides pre-compiled kernel binaries for FlashInfer, supporting a wide range of GPU architectures. This optional package for `flashinfer-python` eliminates JIT compilation and downloading overhead at runtime, leading to faster initialization and enabling offline usage. The FlashInfer project focuses on delivering high-performance LLM GPU kernels for serving and inference, maintaining an active development cycle with frequent nightly builds and regular patch releases.

Install
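A typical installation, assuming both packages are published on PyPI under the names used above (`flashinfer-python` and its optional companion `flashinfer-cubin`):

```shell
# Install FlashInfer plus the optional pre-compiled kernel package.
# With flashinfer-cubin present, flashinfer-python can load kernel binaries
# from the package instead of JIT-compiling or downloading them at runtime.
pip install flashinfer-python flashinfer-cubin
```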

Quickstart

This quickstart demonstrates the usage of the core `flashinfer` library for a single-request decode attention operation. When `flashinfer-cubin` is installed, it transparently provides pre-compiled CUDA kernels to `flashinfer-python`, significantly speeding up operations like this by avoiding runtime compilation overhead.

import torch
import flashinfer

# Single-request decode attention: one query token attends to the full KV cache.
# With flashinfer-cubin installed, the kernel binary is loaded from the package
# instead of being JIT-compiled on first use.

kv_len = 2048
num_qo_heads = 32
num_kv_heads = 32
head_dim = 128

# Query for the current decode step: [num_qo_heads, head_dim]
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda")

# KV cache in NHD layout: [kv_len, num_kv_heads, head_dim]
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# Fused decode attention over the whole KV cache, without materializing
# the attention matrix
output = flashinfer.single_decode_with_kv_cache(q, k, v)

print(output.shape)  # [num_qo_heads, head_dim]
