Flash Attention 4 (CUTE implementation)

Flash Attention 4 is the next-generation implementation of the Flash Attention algorithm using NVIDIA CUTE (CUDA Template Engine). It provides highly optimized fused attention kernels for modern GPUs, supporting head dimensions up to 256 and various data types including FP8. Version 4.0.0b12 is in beta, with frequent releases.
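
Because the package is still in beta and releases frequently, it can help to confirm which build is installed. A minimal check using only the standard library (it assumes the package was installed from PyPI under the name 'flash-attn-4'):

from importlib.metadata import version, PackageNotFoundError

# Report the installed build of the beta package, e.g. '4.0.0b12'.
try:
    print(version("flash-attn-4"))
except PackageNotFoundError:
    print("flash-attn-4 is not installed")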

pip install flash-attn-4
error ModuleNotFoundError: No module named 'flash_attn_4'
cause Wrong package installed: 'flash-attn' (Flash Attention 2/3) instead of 'flash-attn-4'.
fix Run: pip install flash-attn-4
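
If it is unclear which package ended up in the environment, a quick check using only the standard library (the module names follow the gotcha below):

import importlib.util

# 'flash_attn_4' is the new FA4 module; 'flash_attn' is the old FA2/3 module.
for name in ("flash_attn_4", "flash_attn"):
    found = importlib.util.find_spec(name) is not None
    print(f"{name}: {'importable' if found else 'not installed'}")
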
error RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
cause Tensors not moved to GPU before calling flash_attn_func.
fix Ensure q, k, and v are all CUDA tensors, e.g. q = q.cuda(), k = k.cuda(), v = v.cuda().
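
A minimal sketch of the fix, using the flash_attn_func call shown in the example at the end of this page (tensor shapes are illustrative):

import torch
from flash_attn_4 import flash_attn_func

# Tensors created on the CPU (the situation that triggers the error)...
q = torch.randn(1, 4, 128, 64, dtype=torch.float16)
k = torch.randn(1, 4, 128, 64, dtype=torch.float16)
v = torch.randn(1, 4, 128, 64, dtype=torch.float16)

# ...must be moved to the GPU before calling the fused kernel.
q, k, v = q.cuda(), k.cuda(), v.cuda()
out, lse = flash_attn_func(q, k, v, causal=True)
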
breaking Flash Attention 4 is a completely new implementation using CUTE. The API has changed; functions like flash_attn_func now return a tuple (out, lse) instead of just out.
fix Update code to unpack the tuple: out, lse = flash_attn_func(...)
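
A before/after sketch of the change, assuming q, k, v are CUDA tensors as in the example at the bottom of this page:

# Old (FA2/3): out = flash_attn_func(q, k, v, causal=True)
# New (FA4): the call returns a tuple; unpack it.
out, lse = flash_attn_func(q, k, v, causal=True)
# If only the attention output is needed, discard the second value:
# out, _ = flash_attn_func(q, k, v, causal=True)
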
gotcha The PyPI package name is 'flash-attn-4', and the Python module is 'flash_attn_4'. Do not confuse with the old 'flash-attn' package (Flash Attention 2/3).
fix Use 'pip install flash-attn-4' and 'import flash_attn_4'.
gotcha Flash Attention 4 only supports CUDA GPUs with compute capability 8.0+ (Ampere, Hopper, Blackwell). It will fail on older GPUs.
fix Check GPU compute capability via torch.cuda.get_device_capability(). Minimum 8.0 required.
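
A guard that can run before calling the kernels, using the PyTorch call named in the fix:

import torch

# Flash Attention 4 requires a CUDA GPU with compute capability 8.0 or higher.
if not torch.cuda.is_available():
    raise RuntimeError("No CUDA GPU available")
major, minor = torch.cuda.get_device_capability()
if (major, minor) < (8, 0):
    raise RuntimeError(f"Compute capability {major}.{minor} is below the required 8.0")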

Basic forward pass with causal masking.

import torch
from flash_attn_4 import flash_attn_func

# q, k, v: fp16 tensors on the GPU, shape (1, 4, 128, 64)
q = torch.randn(1, 4, 128, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 4, 128, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 4, 128, 64, device='cuda', dtype=torch.float16)

# flash_attn_func returns a tuple (out, lse), not just out
out, lse = flash_attn_func(q, k, v, causal=True)
print(out.shape)