Flash Attention

2.8.3 · active · verified Sun Apr 12

Flash Attention is a fast and memory-efficient exact attention mechanism for deep learning models, particularly Transformers. It reorders the attention computation into blocks that fit in on-chip SRAM, reducing reads and writes to GPU high-bandwidth memory, which makes it significantly faster and less memory-intensive than standard attention. The library is currently stable at version 2.8.3, with version 4.0.0 in active beta development, introducing new features and architectural changes. Its release cadence is driven by research advancements and performance optimizations.
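The key idea behind the reordering is "online softmax": attention can be computed over key/value blocks one at a time, carrying only a running max, a running normalizer, and an output accumulator, and the result is exactly equal to the standard softmax attention. Below is a minimal pure-Python sketch of that trick for a single query row; it is an illustration of the algorithm's math, not the actual CUDA kernel, and the function names are made up for this example.

```python
import math

def naive_attention_row(scores, values):
    # Reference: softmax(scores) @ values for one query row.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    dim = len(values[0])
    return [sum(e * v[j] for e, v in zip(exps, values)) / denom
            for j in range(dim)]

def online_attention_row(scores, values, block=2):
    # Blockwise "online softmax": visit each key/value block once,
    # maintaining a running max (m), normalizer (l), and output
    # accumulator (acc). Exactly matches the naive result.
    m = float('-inf')
    l = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        # Rescale previous partial sums to the new running max.
        scale = math.exp(m - m_new) if m != float('-inf') else 0.0
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]
```

Because the accumulator is rescaled whenever the running max changes, the blockwise pass never needs to materialize the full attention matrix, which is where the memory savings come from.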

Quickstart

This quickstart demonstrates how to use `flash_attn_func` with separate query, key, and value tensors. It highlights the importance of data types (float16/bfloat16) and device placement (CUDA) for optimal performance. The `causal=True` argument is common for generative models. Ensure your `head_dim` is a multiple of 8 and ideally no more than 256.

import torch
from flash_attn import flash_attn_func

# Example for q, k, v as separate tensors
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64 # Must be multiple of 8, typically <= 256

dtype = torch.float16 # FlashAttention works best with float16 or bfloat16
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cpu':
    raise RuntimeError("FlashAttention requires a CUDA GPU; there is no CPU kernel.")

q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)

# Causal (autoregressive) attention, as used in language models.
# softmax_scale=None defaults to 1/sqrt(head_dim).
output = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)

print("Output shape:", output.shape)
print("Output device:", output.device)
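The quickstart above mentions input constraints (float16/bfloat16 dtypes, head_dim a multiple of 8 and typically at most 256). A small validation helper like the following, sketched with a hypothetical name and limits taken from the text above rather than from the library itself, can surface these problems before the kernel raises a less friendly error:

```python
def validate_flash_attn_inputs(shape, dtype, head_dim_max=256):
    """Pre-flight checks for flash_attn_func inputs.

    shape is (batch, seq_len, num_heads, head_dim); dtype is a string
    such as 'float16'. Returns a list of human-readable error messages
    (empty if the inputs look acceptable).
    """
    batch, seq_len, num_heads, head_dim = shape
    errors = []
    if dtype not in ('float16', 'bfloat16'):
        errors.append(f"dtype {dtype!r} unsupported; use float16 or bfloat16")
    if head_dim % 8 != 0:
        errors.append(f"head_dim {head_dim} is not a multiple of 8")
    if head_dim > head_dim_max:
        errors.append(f"head_dim {head_dim} exceeds the typical limit of {head_dim_max}")
    return errors
```

For the quickstart's tensors, `validate_flash_attn_inputs((2, 128, 8, 64), 'float16')` returns an empty list.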
