Core operations for flash-linear-attention
fla-core is a Python library providing efficient, Triton-based implementations of core operations and kernels for state-of-the-art linear attention and state-space models. It is a minimal-dependency subset of the larger `flash-linear-attention` project, focusing on the fundamental computational building blocks. It is currently at version 0.4.2 and follows a regular release cadence, often in lockstep with its parent project, `flash-linear-attention`.
Warnings
- gotcha The `fla-core` package is a minimal subset of `flash-linear-attention`. It contains core kernels and modules (e.g., in `fla.ops` and `fla.modules`) but lacks higher-level layers and models (e.g., `fla.layers`, `fla.models`). Attempting to import these high-level components with only `fla-core` installed will result in an `ImportError`.
- breaking The input tensor format for some kernels switched from 'head-first' to 'sequence-first': a tensor previously laid out as `(batch, heads, sequence, dim)` is now expected as `(batch, sequence, heads, dim)`. Code written against the old layout must transpose its inputs accordingly.
- gotcha fla-core relies heavily on Triton for its optimized kernels. Specific Triton versions (>=3.0 or nightly) and correct backend installations are required, especially on AMD ROCm or Intel XPU GPUs.
- gotcha Requires Python 3.10 or higher. Running with older Python versions will lead to installation or runtime errors.
- gotcha Users on H100 GPUs may encounter 'MMA Assertion Error' or 'LinearLayout Assertion Error' due to known Triton issues.
- gotcha PyTorch version requirement: fla-core expects PyTorch >= 2.5. Older versions may cause compatibility issues or runtime errors.
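The head-first to sequence-first layout change above can be illustrated with a plain-PyTorch sketch (the shapes here are illustrative only; check each kernel's docstring for the layout it actually expects):

```python
import torch

batch, heads, seq, dim = 2, 4, 8, 16

# Head-first layout: (batch, heads, sequence, dim)
q_head_first = torch.randn(batch, heads, seq, dim)

# Sequence-first layout: (batch, sequence, heads, dim)
# Converting is a single transpose; .contiguous() materializes the new layout.
q_seq_first = q_head_first.transpose(1, 2).contiguous()

print(q_head_first.shape)  # torch.Size([2, 4, 8, 16])
print(q_seq_first.shape)   # torch.Size([2, 8, 4, 16])
```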
Install
- pip
pip install fla-core
Imports
- FusedRMSNormGated
from fla.modules import FusedRMSNormGated
- chunk_kda
from fla.ops.kda import chunk_kda
- MultiScaleRetention (not included in `fla-core`; requires the full `flash-linear-attention` package)
from fla.layers import MultiScaleRetention
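Because `fla-core` omits high-level components such as `fla.layers`, a guarded import can distinguish the minimal install from the full `flash-linear-attention` install (a small sketch; the class name comes from the Imports list above):

```python
try:
    # Present only when the full flash-linear-attention package is installed
    from fla.layers import MultiScaleRetention
    has_layers = True
except ImportError:
    # Only fla-core (or neither package) is installed:
    # high-level layers and models are unavailable
    has_layers = False

print("fla.layers available:", has_layers)
```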
Quickstart
import torch
from fla.modules import FusedRMSNormGated
# fla-core's Triton kernels require a CUDA-enabled GPU
if not torch.cuda.is_available():
    raise RuntimeError("CUDA not available. fla-core requires a CUDA-enabled GPU.")
device = torch.device("cuda")
# Define model parameters
hidden_size = 768
batch_size = 4
sequence_length = 512
# Initialize the FusedRMSNormGated module from fla-core
norm_layer = FusedRMSNormGated(hidden_size).to(device)
# Create dummy input and gate tensors; the module normalizes the first
# argument and gates the result with the activation of the second
input_tensor = torch.randn(batch_size, sequence_length, hidden_size, device=device, dtype=torch.float16)
gate_tensor = torch.randn(batch_size, sequence_length, hidden_size, device=device, dtype=torch.float16)
# Perform a forward pass: output = rmsnorm(input) * act(gate)
output_tensor = norm_layer(input_tensor, gate_tensor)
print(f"Input tensor shape: {input_tensor.shape}")
print(f"Output tensor shape: {output_tensor.shape}")
print("FusedRMSNormGated operation successful, demonstrating fla-core usage.")