Flash Attention
Flash Attention is a fast, memory-efficient, exact attention implementation for deep learning models, particularly Transformers. It reorders the attention computation into tiles that stay in on-chip SRAM, cutting reads and writes to GPU high-bandwidth memory; this makes it significantly faster and less memory-hungry than standard attention while producing the same output. The library is currently stable at version 2.8.3, with FlashAttention-3, a beta rewrite targeting Hopper GPUs, under active development with new kernels and API changes. Its release cadence is driven by research advances and performance optimizations.
Warnings
- breaking The signature of `flash_attn_func` has changed significantly across v1, v2, and the FlashAttention-3 beta: argument order, default values, added parameters (e.g., `softmax_scale`, `dropout_p`, `causal`), and return values all differ. Code written for v1 or early v2 will likely break on later releases.
- gotcha Flash Attention requires a recent CUDA architecture (roughly SM75+/Turing for v1, SM80+/Ampere for v2, SM90/Hopper for the FlashAttention-3 beta) and constrains `head_dim`: it must be a multiple of 8 and at most 256 (common values are 64 and 128). Using an unsupported `head_dim` or GPU architecture results in runtime errors, or in fallbacks to slower implementations when a wrapper library provides them.
- gotcha Installation is sensitive to your PyTorch and CUDA setup. Running `pip install flash-attn` without `--no-build-isolation` can compile the extension against a different CUDA toolkit than the one your PyTorch build uses, causing import failures or crashes at runtime.
- breaking The FlashAttention-3 beta is a separate implementation with its own build (under the `hopper/` directory of the repository) and its own entry points, including a `flash_attn_func` with updated arguments and variable-length sequence support. Code written for v2.x is NOT directly compatible with the beta, and vice versa.
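The `head_dim` constraints above can be checked before building a model, so misconfiguration fails fast with a clear message. This is a hypothetical helper, not part of the library; the limits encode the v2 rules stated above:

```python
def check_head_dim(head_dim: int, max_dim: int = 256) -> None:
    """Raise early if head_dim violates the usual FlashAttention v2 limits
    (multiple of 8, at most 256). Hypothetical helper, not a library API."""
    if head_dim % 8 != 0:
        raise ValueError(f"head_dim={head_dim} must be a multiple of 8")
    if head_dim > max_dim:
        raise ValueError(f"head_dim={head_dim} exceeds the supported maximum of {max_dim}")

check_head_dim(64)   # OK: common value
check_head_dim(128)  # OK: common value
```

Calling it once when the model config is parsed turns a cryptic CUDA runtime error into an immediate, readable exception.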
Install
- pip install flash-attn --no-build-isolation
- MAX_JOBS=4 pip install flash-attn --no-build-isolation  (limit parallel compile jobs on machines with little RAM)
Imports
- flash_attn_func
from flash_attn import flash_attn_func
- flash_attn_qkvpacked_func
from flash_attn import flash_attn_qkvpacked_func
- flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func
- MHA (module wrapper)
from flash_attn.modules.mha import MHA
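`flash_attn_varlen_func` takes sequences packed into one tensor plus cumulative sequence lengths (`cu_seqlens`), an int32 tensor of offsets marking where each sequence starts and ends. The offsets are just a running sum, sketched here in plain Python (in real use they are converted to a `torch.int32` tensor on the GPU; the varlen call itself is omitted):

```python
def cu_seqlens(seq_lens):
    """Cumulative offsets [0, l0, l0+l1, ...] in the shape that
    flash_attn_varlen_func expects for packed variable-length batches."""
    offsets = [0]
    for n in seq_lens:
        offsets.append(offsets[-1] + n)
    return offsets

# Three sequences of lengths 3, 5, 2 packed into one (10, num_heads, head_dim) tensor:
print(cu_seqlens([3, 5, 2]))  # [0, 3, 8, 10]
```

Sequence i then occupies rows `cu_seqlens[i]:cu_seqlens[i+1]` of the packed tensor, which is why the offsets list has one more entry than there are sequences.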
Quickstart
import torch
from flash_attn import flash_attn_func
# Example for q, k, v as separate tensors
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64 # Must be multiple of 8, typically <= 256
dtype = torch.float16 # FlashAttention works best with float16 or bfloat16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    print("Warning: FlashAttention requires a CUDA GPU; the call below will fail on CPU.")
q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
# Causal attention (for language models)
output = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
print("Output shape:", output.shape)
print("Output device:", output.device)
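For reference, `flash_attn_func` computes ordinary softmax attention, `softmax(Q Kᵀ / sqrt(d)) V`, only with a tiled, memory-efficient schedule; the result is mathematically the same. A tiny pure-Python version for a single head makes the semantics concrete (illustrative only: no batching, no dropout, no tiling):

```python
import math

def naive_attention(q, k, v, causal=False):
    """softmax(q @ k^T / sqrt(d)) @ v for one head; q, k, v are lists of d-dim rows."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        limit = i + 1 if causal else len(k)  # causal: position i sees only positions <= i
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k[:limit]]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]  # subtract max for numerical stability
        z = sum(exps)
        weights = [e / z for e in exps]
        out.append([sum(w * vj[t] for w, vj in zip(weights, v[:limit]))
                    for t in range(len(v[0]))])
    return out

# Two positions, head_dim 2, causal: position 0 attends only to itself,
# so its output is exactly v[0].
out = naive_attention([[1.0, 0.0], [0.0, 1.0]],
                      [[1.0, 0.0], [0.0, 1.0]],
                      [[1.0, 2.0], [3.0, 4.0]], causal=True)
print(out[0])  # [1.0, 2.0]
```

Flash Attention produces this same result without ever materializing the full `seq_len x seq_len` score matrix, which is where the memory savings come from.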