FlashInfer: Kernel Library for LLM Serving
FlashInfer is a high-performance kernel library for Large Language Model (LLM) inference on NVIDIA GPUs. It provides efficient CUDA kernels for operations such as paged attention, prefill, and decode, exposed through Python wrappers that integrate with PyTorch. The library is pre-1.0 and under active development, with frequent patch releases and nightly builds, so expect API changes between versions.
Warnings
- gotcha FlashInfer is a kernel library requiring an NVIDIA GPU with a compatible CUDA runtime. It will not work on CPUs or other accelerators. Using pre-built wheels requires a matching CUDA toolkit version (e.g., cu118 for CUDA 11.8); a mismatch often leads to `RuntimeError` or `ModuleNotFoundError`.
- breaking The library is under active development and not yet at a 1.0 release. Frequent minor and patch releases (including nightly builds) may introduce API changes or breaking modifications to function signatures and class constructors.
- gotcha FlashInfer's API, particularly the paged-KV-cache attention wrappers, is relatively low-level. Incorrect page-table metadata (`indptr`, `indices`, last-page lengths) or workspace-buffer management can lead to subtle bugs, incorrect attention output, or illegal memory accesses.
- gotcha FlashInfer is tightly coupled with PyTorch for tensor operations and device management. While `torch` is a dependency, ensure your PyTorch version is compatible, especially when using specific CUDA versions or pre-built FlashInfer wheels.
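The page-table metadata mentioned above is the easiest thing to get wrong. As a minimal sketch (pure PyTorch, no GPU required; the helper name `build_page_table` is hypothetical, not part of FlashInfer), this is the CSR-style layout the paged-KV wrappers expect, assuming pages are allocated contiguously per sequence:

```python
import torch

def build_page_table(seq_lens, page_size):
    """Build CSR-style paged-KV metadata (indptr, indices, last_page_len)
    for a batch of variable-length sequences. Hypothetical helper for
    illustration; assumes contiguous page allocation."""
    indptr = [0]
    indices = []
    last_page_len = []
    next_page = 0
    for n in seq_lens:
        num_pages = (n + page_size - 1) // page_size  # ceil division
        indices.extend(range(next_page, next_page + num_pages))
        next_page += num_pages
        indptr.append(len(indices))
        # number of tokens occupying the final page (in 1..page_size)
        last_page_len.append(n - (num_pages - 1) * page_size)
    return (
        torch.tensor(indptr, dtype=torch.int32),
        torch.tensor(indices, dtype=torch.int32),
        torch.tensor(last_page_len, dtype=torch.int32),
    )

indptr, indices, last = build_page_table([50, 17], page_size=16)
print(indptr.tolist())  # [0, 4, 6] — seq 0 uses 4 pages, seq 1 uses 2
print(last.tolist())    # [2, 1]   — 50 = 3*16 + 2, 17 = 1*16 + 1
```

A mismatch between `indptr[-1]` and `len(indices)`, or a `last_page_len` outside `1..page_size`, is exactly the kind of setup error that produces wrong attention output or memory violations.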
Install
- pip install flashinfer-python
- pip install flashinfer-python --pre --extra-index-url https://flashinfer.ai/whl/cu121
Imports
- flashinfer
import flashinfer as fi
- BatchDecodeWithPagedKVCacheWrapper
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
- BatchPrefillWithRaggedKVCacheWrapper
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
- single_decode_with_kv_cache
from flashinfer import single_decode_with_kv_cache
Quickstart
import torch
import flashinfer as fi
# Ensure CUDA is available
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. FlashInfer requires a CUDA-enabled GPU.")
# Device and dtype
device = "cuda"
dtype = torch.float16
# Model parameters (simplified for example)
num_qo_heads = 32
num_kv_heads = 32
head_dim = 128
page_size = 16
batch_size = 1  # decoding one sequence
seq_len = 50    # tokens already in the KV cache for that sequence
# 1. Allocate the paged KV cache as a plain tensor (NHD layout):
#    [max_num_pages, 2 (K/V), page_size, num_kv_heads, head_dim]
num_pages = (seq_len + page_size - 1) // page_size
kv_cache = torch.randn(
    num_pages, 2, page_size, num_kv_heads, head_dim, dtype=dtype, device=device
)
# 2. Build the CSR-style page table for the batch (one sequence here)
kv_page_indptr = torch.tensor([0, num_pages], dtype=torch.int32, device=device)
kv_page_indices = torch.arange(num_pages, dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor(
    [seq_len - (num_pages - 1) * page_size], dtype=torch.int32, device=device
)
# 3. Create the decode wrapper with a workspace buffer, then plan the kernel.
#    (Older releases name these steps begin_forward/forward instead of plan/run.)
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
decode_wrapper = fi.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
decode_wrapper.plan(
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    pos_encoding_mode="NONE",
)
# 4. Query for the next token, shape (batch_size, num_qo_heads, head_dim)
query = torch.randn(batch_size, num_qo_heads, head_dim, dtype=dtype, device=device)
# 5. Run the decode attention kernel (softmax scale defaults to 1/sqrt(head_dim))
output = decode_wrapper.run(query, kv_cache)
print(f"FlashInfer BatchDecode output shape: {output.shape}")
print("FlashInfer decode successful.")
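To sanity-check the wrapper's output, it helps to spell out what batch decode computes: for each sequence, single-query attention against the K/V tokens gathered from that sequence's pages. A CPU reference sketch (pure PyTorch; the function name is hypothetical and this is not FlashInfer's implementation, only the math it accelerates):

```python
import torch

def decode_attention_reference(q, kv_cache, page_indices, last_page_len):
    """Single-query attention over a paged KV cache laid out as
    kv_cache[page, 0=K/1=V, slot, kv_head, head_dim]. CPU reference only."""
    # Gather this sequence's pages into contiguous K and V
    k_pages = kv_cache[page_indices, 0]  # [num_pages, page_size, heads, dim]
    v_pages = kv_cache[page_indices, 1]
    page_size = kv_cache.shape[2]
    seq_len = (len(page_indices) - 1) * page_size + last_page_len
    k = k_pages.flatten(0, 1)[:seq_len]  # [seq_len, heads, dim]
    v = v_pages.flatten(0, 1)[:seq_len]
    scale = q.shape[-1] ** -0.5
    # q: [heads, dim]; scores: [heads, seq_len]
    scores = torch.einsum("hd,shd->hs", q.float(), k.float()) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", attn, v.float())

torch.manual_seed(0)
kv = torch.randn(4, 2, 16, 8, 64)  # 4 pages, page_size 16, 8 KV heads, dim 64
q = torch.randn(8, 64)             # one query token, 8 heads
out = decode_attention_reference(q, kv, torch.arange(4), last_page_len=2)
print(out.shape)  # torch.Size([8, 64])
```

Note how the last page is only partially attended over (`last_page_len` tokens); a kernel that reads the full final page silently attends to stale or uninitialized slots, which is one of the subtle bugs the warnings above describe.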