Flash Linear Attention
Flash Linear Attention (FLA) is a Python library providing efficient, Triton-based implementations for state-of-the-art linear attention models and emerging sequence modeling architectures. It aims for high-performance training and inference across NVIDIA, AMD, and Intel GPUs. As of version 0.4.2, the library is actively maintained with frequent releases, offering optimized kernels, fused modules, and integration-ready layers for PyTorch and Hugging Face models.
Warnings
- breaking Starting from v0.3.2, the `flash-linear-attention` package was split into `fla-core` (minimal dependencies) and `flash-linear-attention` (extension, including `fla/layers` and `fla/models`, depending on `transformers`). Users upgrading from older versions or relying on direct `fla.ops` imports may experience changes in dependency management or module resolution.
- breaking In November 2024, the input tensor format was switched from head-first, i.e. `(batch, num_heads, seq_len, head_dim)`, to sequence-first, i.e. `(batch, seq_len, num_heads, head_dim)`. Code that still passes head-first tensors to the kernels must be updated to the new dimension order.
- gotcha Strict compatibility between PyTorch and Triton versions is required. Using incompatible versions can lead to `AttributeError` (e.g., `'NoneType' object has no attribute 'start'`) or `LinearLayout Assertion Error`. This is especially relevant for nightly builds or specific hardware (like ARM).
- gotcha For AMD and Intel GPUs, specific Triton ROCm or XPU backends are required, which might need separate installation steps beyond `pip install triton`. Without the correct backend, performance will be severely impacted or the library may not function.
- gotcha The library explicitly requires Python 3.10 or newer. Older Python versions can lead to `AttributeError: 'NoneType' object has no attribute 'start'` during Triton kernel compilation.
- deprecated The external `causal-conv1d` library is no longer a required dependency as `flash-linear-attention` now provides its own Triton implementations for `conv1d` operations.
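The head-first to sequence-first layout change above can be illustrated with a small pure-Python sketch (nested lists stand in for tensors; with real PyTorch tensors the equivalent one-liner is `x.transpose(1, 2)`):

```python
def head_first_to_seq_first(x):
    """Convert a nested-list "tensor" from head-first layout
    [batch, num_heads, seq_len, head_dim] to sequence-first layout
    [batch, seq_len, num_heads, head_dim] (the post-Nov-2024 format).

    Pure-Python illustration only; for a torch.Tensor use x.transpose(1, 2).
    """
    batch, heads = len(x), len(x[0])
    seq_len = len(x[0][0])
    return [
        [[x[b][h][t] for h in range(heads)] for t in range(seq_len)]
        for b in range(batch)
    ]

# Tiny example: batch=1, num_heads=2, seq_len=3, head_dim=1
hf = [[[[1], [2], [3]], [[4], [5], [6]]]]
sf = head_first_to_seq_first(hf)
print(sf)  # [[[[1], [4]], [[2], [5]], [[3], [6]]]]
```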
Install
- pip install flash-linear-attention
- pip install torch triton einops transformers numpy
  # For AMD GPUs, ensure the Triton ROCm backend is installed separately.
  # For Intel GPUs, ensure the Triton XPU backend is installed separately.
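After installing, the Python-version and missing-dependency gotchas listed in Warnings can be caught early with a small stdlib-only check. This is a sketch: it reports what is installed but does not verify that a given torch/Triton pair is actually compatible, since the valid pairings change per release.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

def installed_version(pkg):
    """Return the installed version string of pkg, or None if absent."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None

def check_environment():
    """Fail fast on known FLA pitfalls instead of deep inside a Triton build."""
    # Python < 3.10 surfaces later as a cryptic Triton compilation error.
    if sys.version_info < (3, 10):
        raise RuntimeError("flash-linear-attention requires Python >= 3.10")
    # Print what is installed so version mismatches are easy to spot.
    for pkg in ("torch", "triton"):
        print(pkg, installed_version(pkg) or "not installed")
```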
Imports
- MultiScaleRetention
from fla.layers import MultiScaleRetention
- FlashMamba
from fla.models import FlashMamba
Quickstart
import torch
from fla.layers import MultiScaleRetention
# Example input tensor (batch_size, sequence_length, hidden_dim)
batch_size = 2
sequence_length = 512
hidden_dim = 128
# The Triton kernels require a GPU; move tensors and the layer to CUDA.
if torch.cuda.is_available():
    input_tensor = torch.randn(batch_size, sequence_length, hidden_dim).cuda()
    # Initialize the MultiScaleRetention layer;
    # hidden_size must match the input's last dimension.
    model = MultiScaleRetention(hidden_size=hidden_dim, num_heads=4).cuda()
    # Forward pass: the layer returns a tuple whose first element is the output.
    output_tensor, *_ = model(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")
else:
    print("CUDA is not available. Please ensure a compatible GPU and PyTorch installation.")
    print("Tensors and the model must be on the GPU for Flash Linear Attention.")