Flash Linear Attention

0.4.2 · active · verified Wed Apr 15

Flash Linear Attention (FLA) is a Python library providing efficient, Triton-based implementations for state-of-the-art linear attention models and emerging sequence modeling architectures. It aims for high-performance training and inference across NVIDIA, AMD, and Intel GPUs. As of version 0.4.2, the library is actively maintained with frequent releases, offering optimized kernels, fused modules, and integration-ready layers for PyTorch and Hugging Face models.
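As background for what these kernels compute: linear attention drops the softmax over query–key scores, which lets causal attention be rewritten as an associative recurrent state update and is what makes O(N) implementations like FLA's possible. Below is a minimal pure-PyTorch sketch of that recurrence, for illustration only; it omits the feature map and normalization used in practice and is not FLA's Triton implementation.

```python
import torch

def linear_attention_reference(q, k, v):
    """Naive causal linear attention.

    Recurrence: S_t = S_{t-1} + k_t^T v_t, then o_t = q_t S_t.
    q, k, v: (batch, seq_len, d). Real implementations also apply a
    positive feature map and a normalizer; omitted here for clarity.
    """
    batch, seq_len, d = q.shape
    state = torch.zeros(batch, d, d)  # running sum of k^T v outer products
    outputs = []
    for t in range(seq_len):
        kt = k[:, t, :]  # (batch, d)
        vt = v[:, t, :]
        state = state + kt.unsqueeze(-1) * vt.unsqueeze(1)  # (batch, d, d)
        outputs.append(torch.einsum('bd,bde->be', q[:, t, :], state))
    return torch.stack(outputs, dim=1)  # (batch, seq_len, d)
```

Because the state update is associative, this loop produces exactly the same result as the masked parallel form `(q @ k.transpose(-1, -2)).tril() @ v`, which is the equivalence the chunked FLA kernels exploit.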

Warnings

The kernels are implemented in Triton and require a supported GPU; the quickstart will not run on CPU-only installations.
Install

pip install flash-linear-attention
Imports

import torch
from fla.layers import MultiScaleRetention
Quickstart

This quickstart demonstrates how to initialize and use a `MultiScaleRetention` layer from `flash-linear-attention` with a sample PyTorch tensor. It's crucial to run this on a CUDA-enabled GPU for the performance benefits of Triton kernels.

import torch
from fla.layers import MultiScaleRetention

# Example input tensor (batch_size, sequence_length, hidden_dim)
batch_size = 2
sequence_length = 512
hidden_dim = 128

# Ensure CUDA is available and tensors are on GPU for optimal performance
if torch.cuda.is_available():
    # The upstream examples run these layers in bfloat16
    input_tensor = torch.randn(
        batch_size, sequence_length, hidden_dim, dtype=torch.bfloat16
    ).cuda()
    # Initialize the MultiScaleRetention layer;
    # hidden_size must match the input's last dimension
    model = MultiScaleRetention(hidden_size=hidden_dim, num_heads=4).to(torch.bfloat16).cuda()

    # Forward pass: the layer returns a tuple whose first element is the output
    output_tensor, *_ = model(input_tensor)
    
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")
else:
    print("CUDA is not available. Please ensure a compatible GPU and PyTorch installation.")
    print("Tensors and model should be moved to GPU for Flash Linear Attention.")
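Conceptually, retention extends the linear-attention recurrence with an exponential decay on the state (RetNet's "multi-scale" variant uses a different decay rate per head). The following single-head, pure-PyTorch sketch illustrates the idea; the scalar `decay` and the function itself are illustrative stand-ins, not FLA's fused Triton path.

```python
import torch

def retention_reference(q, k, v, decay):
    """Naive causal retention (single head, illustrative only).

    Recurrence: S_t = decay * S_{t-1} + k_t^T v_t, then o_t = q_t S_t,
    with decay a scalar in (0, 1). Multi-scale retention assigns each
    head its own decay rate.
    """
    batch, seq_len, d = q.shape
    state = torch.zeros(batch, d, d)
    outputs = []
    for t in range(seq_len):
        state = decay * state + k[:, t, :].unsqueeze(-1) * v[:, t, :].unsqueeze(1)
        outputs.append(torch.einsum('bd,bde->be', q[:, t, :], state))
    return torch.stack(outputs, dim=1)
```

Unrolling the recurrence gives the equivalent parallel form, where position s contributes to position t with weight (q_t · k_s) · decay^(t−s) for s ≤ t; with decay = 1 this reduces to plain linear attention.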
