vLLM Flash Attention Wrapper
Forward-only flash-attention kernel optimized for vLLM inference. The latest release, 2.6.2, is a lightweight wrapper around the FlashAttention CUDA kernels that exposes a simplified forward-only API. Development is active alongside vLLM releases.
pip install vllm-flash-attn

Common errors
error ImportError: No module named 'vllm_flash_attn'
cause Package not installed or installed with an older name (flash_attn instead of vllm_flash_attn).
fix
pip install vllm-flash-attn
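A quick check of which module name is actually importable after installation; this is a small standard-library sketch, not part of the package's own tooling:

import importlib.util

# None means that module is not installed under that name.
print(importlib.util.find_spec("vllm_flash_attn"))   # the package covered here
print(importlib.util.find_spec("flash_attn"))        # the separate upstream flash-attn package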
error RuntimeError: FlashAttention only supports CUDA with compute capability >= 8.0
cause GPU older than Ampere (e.g., V100, GTX 1080).
fix
Run on a GPU with compute capability >= 8.0 (Ampere or newer), or fall back to a non-flash attention implementation.
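A minimal sketch of guarding the import on compute capability (assumes CUDA is available; torch.cuda.get_device_capability returns a (major, minor) tuple):

import torch

major, minor = torch.cuda.get_device_capability()
if (major, minor) < (8, 0):
    # Pre-Ampere GPU: the vllm-flash-attn kernels will refuse to run here.
    raise RuntimeError(f"Need compute capability >= 8.0, got {major}.{minor}")
from vllm_flash_attn import flash_attn_func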
error AssertionError: Input tensor must be contiguous in memory
cause Input tensors are not contiguous; flash attention requires contiguous memory layout.
fix
Call .contiguous() on tensors before passing: q.contiguous(), k.contiguous(), v.contiguous()
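A short sketch of the common case: a transpose produces a non-contiguous view, and .contiguous() copies it into a layout the kernel accepts (shapes are illustrative):

import torch
from vllm_flash_attn import flash_attn_func

x = torch.randn(1, 8, 16, 64, device='cuda', dtype=torch.float16)  # (batch, heads, seq, dim)
q = x.transpose(1, 2)   # (batch, seq, heads, dim) -- a view, not contiguous
k = x.transpose(1, 2)
v = x.transpose(1, 2)
assert not q.is_contiguous()

out = flash_attn_func(q.contiguous(), k.contiguous(), v.contiguous(),
                      softmax_scale=0.125, causal=False)  # 0.125 = 1/sqrt(64)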
Warnings
gotcha This package is forward-only and does not support the backward pass. Using it during training will either crash or silently produce wrong gradients.
fix Use the full flash-attn package (flash_attn) for training.
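A hedged sketch of keeping the two kernels apart per phase; it assumes both packages are installed and uses only the function names listed in the Imports section below:

import torch
from flash_attn import flash_attn_func as flash_attn_train        # supports backward
from vllm_flash_attn import flash_attn_func as flash_attn_infer   # forward-only

def attention(q, k, v, scale, training: bool):
    # Route training through the full package, inference through the vLLM wrapper.
    if training:
        return flash_attn_train(q, k, v, softmax_scale=scale, causal=True)
    with torch.inference_mode():
        return flash_attn_infer(q, k, v, softmax_scale=scale, causal=True)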
deprecated Support for compute capability < 8.0 (e.g., V100) was dropped in v2.6.0. Older versions may still work but are unmaintained.
fix Use an Ampere-class GPU or newer (compute capability >= 8.0), or pin to vllm-flash-attn<2.6.0 if on older hardware.
gotcha The function signature for flash_attn_func changed in v2.6.0: the `softmax_scale` parameter is no longer optional and must be passed explicitly.
fix Always pass softmax_scale explicitly; the conventional value is 1/sqrt(head_dim).
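A minimal sketch of passing the conventional 1/sqrt(head_dim) scale explicitly (shapes are illustrative):

import math
import torch
from vllm_flash_attn import flash_attn_func

head_dim = 64
q = torch.randn(1, 16, 8, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn(1, 16, 8, head_dim, device='cuda', dtype=torch.float16)
v = torch.randn(1, 16, 8, head_dim, device='cuda', dtype=torch.float16)

# softmax_scale is required in v2.6.0+; use the standard attention scaling.
out = flash_attn_func(q, k, v, softmax_scale=1.0 / math.sqrt(head_dim), causal=True)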
Imports
- flash_attn_func
  wrong:   from flash_attn import flash_attn_func
  correct: from vllm_flash_attn import flash_attn_func
- flash_attn_with_kvcache
  wrong:   from flash_attn import flash_attn_with_kvcache
  correct: from vllm_flash_attn import flash_attn_with_kvcache
Quickstart
import torch
from vllm_flash_attn import flash_attn_func

# q, k, v: (batch_size, seq_len, num_heads, head_dim), fp16 on GPU
q = torch.randn(1, 1, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 1, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 1, 8, 64, device='cuda', dtype=torch.float16)

# softmax_scale must be passed explicitly; 1/sqrt(head_dim) = 0.125 is the conventional choice
out = flash_attn_func(q, k, v, softmax_scale=1.0, causal=False)
print(out.shape)  # torch.Size([1, 1, 8, 64])
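flash_attn_with_kvcache is listed in the Imports section but not shown above. The following decode-step sketch assumes the wrapper keeps the upstream flash-attn signature (q, k_cache, v_cache, k=, v=, cache_seqlens=, ...); check the docstring of your installed version before relying on it:

import math
import torch
from vllm_flash_attn import flash_attn_with_kvcache

batch, max_seqlen, nheads, head_dim = 1, 128, 8, 64

# Pre-allocated KV cache; only the first cache_seqlens entries per sequence are valid.
k_cache = torch.zeros(batch, max_seqlen, nheads, head_dim, device='cuda', dtype=torch.float16)
v_cache = torch.zeros(batch, max_seqlen, nheads, head_dim, device='cuda', dtype=torch.float16)
cache_seqlens = torch.tensor([16], dtype=torch.int32, device='cuda')  # tokens already cached

# One new token per sequence (a single decode step).
q     = torch.randn(batch, 1, nheads, head_dim, device='cuda', dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads, head_dim, device='cuda', dtype=torch.float16)
v_new = torch.randn(batch, 1, nheads, head_dim, device='cuda', dtype=torch.float16)

# k_new/v_new are written into the cache at position cache_seqlens, then attended over.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    softmax_scale=1.0 / math.sqrt(head_dim),
    causal=True,
)
print(out.shape)  # (1, 1, 8, 64)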