vLLM Flash Attention Wrapper

Version 2.6.2, verified Fri May 01

Forward-only flash-attention kernel optimized for vLLM inference. Version 2.6.2 is the latest release: a lightweight wrapper around the Flash Attention CUDA kernel that exposes a simplified, forward-only API. Development tracks vLLM releases.

pip install vllm-flash-attn
error ImportError: No module named 'vllm_flash_attn'
cause Package not installed or installed with an older name (flash_attn instead of vllm_flash_attn).
fix
pip install vllm-flash-attn
error RuntimeError: FlashAttention only supports CUDA with compute capability >= 8.0
cause GPU older than Ampere (e.g., V100, GTX 1080).
fix
Upgrade to a GPU with compute capability >= 8.0 (Ampere or newer), or fall back to a standard (non-flash) attention implementation.
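A runtime check can catch unsupported GPUs before the kernel is ever invoked. A minimal sketch; the helper name is mine, and on a real system the capability tuple would come from torch.cuda.get_device_capability():

```python
def supports_flash_attn(capability):
    """Return True if a (major, minor) compute capability tuple meets
    the >= 8.0 (Ampere) requirement of vllm-flash-attn."""
    major, minor = capability
    return (major, minor) >= (8, 0)

# On a CUDA machine you would call:
#   supports_flash_attn(torch.cuda.get_device_capability())
print(supports_flash_attn((8, 0)))  # A100 (Ampere)  -> True
print(supports_flash_attn((7, 0)))  # V100 (Volta)   -> False
```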
error AssertionError: Input tensor must be contiguous in memory
cause Input tensors are not contiguous; flash attention requires contiguous memory layout.
fix
Call .contiguous() on tensors before passing: q.contiguous(), k.contiguous(), v.contiguous()
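Non-contiguous tensors usually come from transpose or permute, which return strided views. A small CPU-side sketch of the failure mode and the fix (the flash kernel itself requires CUDA, but contiguity behaves identically on CPU):

```python
import torch

# transpose returns a view with non-contiguous strides
q = torch.randn(1, 8, 4, 64).transpose(1, 2)  # shape (1, 4, 8, 64)
print(q.is_contiguous())  # False

# .contiguous() copies into the dense layout the kernel expects
q = q.contiguous()
print(q.is_contiguous())  # True
```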
gotcha This package is forward-only. It does not support backward pass gradients. Using it in training will silently produce wrong gradients or crash.
fix Use the full flash-attn package (flash_attn) for training.
deprecated Support for compute capability < 8.0 (e.g., V100) was dropped in v2.6.0. Older versions may still work but are unmaintained.
fix Upgrade to an Ampere-or-newer GPU (compute capability >= 8.0), or pin to vllm-flash-attn<2.6.0 if on older hardware.
gotcha The function signature for flash_attn_func changed in v2.6.0: the `softmax_scale` parameter is no longer optional and must be passed explicitly.
fix Always pass softmax_scale explicitly: commonly 1/sqrt(head_dim), or 1.0 for unscaled attention.
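The conventional scaling factor in attention is 1/sqrt(head_dim). A sketch of computing it explicitly, using the 64-dimensional heads from the example below:

```python
import math

head_dim = 64
# standard attention scaling: 1 / sqrt(d_k)
softmax_scale = 1.0 / math.sqrt(head_dim)
print(softmax_scale)  # 0.125
```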

Basic usage of forward-only flash attention. Requires CUDA GPU.

import torch
from vllm_flash_attn import flash_attn_func

# q, k, v layout: (batch, seqlen, num_heads, head_dim), fp16 on GPU
q = torch.randn(1, 1, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 1, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 1, 8, 64, device='cuda', dtype=torch.float16)
out = flash_attn_func(q, k, v, softmax_scale=1.0, causal=False)  # forward only, no gradients
print(out.shape)  # matches q: (1, 1, 8, 64)