FlashInfer: Kernel Library for LLM Serving

0.6.7.post3 · active · verified Thu Apr 09

FlashInfer is a high-performance kernel library for optimizing Large Language Model (LLM) inference on NVIDIA GPUs. It provides efficient CUDA kernels for operations like paged attention, prefill, and decode. Currently at version 0.6.7.post3, the library is under active development with frequent patch releases and nightly builds, indicating rapid evolution and potential API changes.
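As a point of reference for what the decode kernels compute, here is a naive (non-optimized) PyTorch implementation of single-query attention, the per-token operation a decode step performs. The function name is illustrative, not part of FlashInfer's API:

```python
import torch

def naive_decode_attention(q, k, v):
    """Reference single-query attention, as computed during one decode step.

    q:    (num_heads, head_dim)          - the query for one new token
    k, v: (kv_len, num_heads, head_dim)  - cached keys/values for the sequence
    """
    scale = q.shape[-1] ** -0.5
    # scores[h, l] = <q[h], k[l, h]> * scale  -> shape (num_heads, kv_len)
    scores = torch.einsum("hd,lhd->hl", q.float(), k.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    # out[h, d] = sum_l probs[h, l] * v[l, h, d] -> shape (num_heads, head_dim)
    return torch.einsum("hl,lhd->hd", probs, v.float())

q = torch.randn(32, 128)        # 32 heads, head_dim 128
k = torch.randn(50, 32, 128)    # 50 cached tokens
v = torch.randn(50, 32, 128)
out = naive_decode_attention(q, k, v)
print(out.shape)  # torch.Size([32, 128])
```

FlashInfer's kernels fuse this computation and read K/V from paged GPU memory instead of a contiguous tensor, but the result is semantically the same.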

Quickstart

This quickstart demonstrates how to lay out a paged KV cache and use `BatchDecodeWithPagedKVCacheWrapper` to perform a single-token decode step. It simulates one sequence already resident in the cache and then attends a new query token against it, the typical decode-phase workflow in LLM serving. Note that the wrapper API has changed across releases (older versions used `begin_forward`/`forward`; newer ones use `plan`/`run`); the snippet below follows the `plan`/`run` style.

import torch
import flashinfer

# Ensure CUDA is available
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. FlashInfer requires a CUDA-enabled GPU.")

# Device and dtype
device = "cuda:0"
dtype = torch.float16

# Model parameters (simplified for the example)
num_qo_heads = 32
num_kv_heads = 32
head_dim = 128
page_size = 16
max_num_pages = 128  # capacity of the paged KV cache, in pages

# 1. Allocate the paged KV cache as a plain tensor.
# NHD layout: (max_num_pages, 2, page_size, num_kv_heads, head_dim),
# where dimension 1 holds K and V. In real serving, a prefill pass would
# populate this; here we fill it with random values to simulate a cache
# that is ready for decode.
kv_cache = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=dtype, device=device,
)

# 2. Create the decode wrapper. It needs a workspace buffer for the
# intermediate results of the attention kernels.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# 3. Describe one sequence of 50 cached tokens via CSR page-table metadata.
# 50 tokens at page_size=16 occupy 4 pages; the last page holds
# 50 - 3*16 = 2 valid tokens.
batch_size = 1
prefill_len = 50  # length of the sequence already in cache
num_pages_used = (prefill_len + page_size - 1) // page_size
kv_page_indptr = torch.tensor([0, num_pages_used], dtype=torch.int32, device=device)
kv_page_indices = torch.arange(num_pages_used, dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor(
    [prefill_len - (num_pages_used - 1) * page_size], dtype=torch.int32, device=device
)

# 4. Plan the decode: precomputes the kernel schedule for this batch shape.
decode_wrapper.plan(
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    pos_encoding_mode="NONE",
    data_type=dtype,
)

# 5. Query for the next token, shape (batch_size, num_qo_heads, head_dim)
query = torch.randn(batch_size, num_qo_heads, head_dim, dtype=dtype, device=device)

# 6. Run the decode kernel.
output = decode_wrapper.run(query, kv_cache)

print(f"FlashInfer BatchDecode output shape: {output.shape}")  # (1, 32, 128)
print("FlashInfer decode successful.")
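Under the hood, a paged KV cache tracks which pages each sequence occupies, and the batch wrappers consume that mapping in CSR form (page indptr, page indices, last-page lengths). Deriving it from per-sequence lengths can be sketched in plain Python; `build_page_table` is a hypothetical helper, and page ids are assumed to be allocated contiguously for illustration:

```python
def build_page_table(seq_lens, page_size):
    """Build CSR-style page-table metadata from per-sequence KV lengths.

    Hypothetical helper for illustration; real allocators hand out
    arbitrary (non-contiguous) page ids from a free list.
    """
    indptr = [0]        # CSR offsets into the page-index array, len = batch + 1
    indices = []        # page ids, concatenated per sequence
    last_page_len = []  # number of valid tokens in each sequence's final page
    next_page = 0
    for n in seq_lens:
        num_pages = (n + page_size - 1) // page_size  # ceil division
        indices.extend(range(next_page, next_page + num_pages))
        next_page += num_pages
        indptr.append(len(indices))
        last_page_len.append(n - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len

indptr, indices, last_page_len = build_page_table([50, 17], page_size=16)
print(indptr)         # [0, 4, 6]  (seq 0 uses 4 pages, seq 1 uses 2)
print(last_page_len)  # [2, 1]     (50 = 3*16 + 2, 17 = 16 + 1)
```

The CSR layout lets one flat index array describe a ragged batch: sequence `i` owns `indices[indptr[i]:indptr[i+1]]`, and only `last_page_len[i]` tokens of its final page are attended.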
