Core operations for flash-linear-attention

0.4.2 · active · verified Wed Apr 15

fla-core is a Python library providing efficient, Triton-based implementations of core operations and kernels for state-of-the-art linear-attention and state-space models. It is a minimal-dependency subset of the larger flash-linear-attention project, focusing on the fundamental computational building blocks. It is currently at version 0.4.2 and follows a regular release cadence, often in step with its parent project.

Warnings

fla-core requires a CUDA-capable GPU (its kernels are written in Triton); the quickstart below raises a RuntimeError on CPU-only machines.

Install
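Assuming the package is published on PyPI under the name shown above, it can be installed with pip:

```shell
# Install fla-core from PyPI (pulls in torch and triton as dependencies)
pip install fla-core
```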

Imports

Quickstart

This quickstart demonstrates the use of a fused normalization module from `fla-core`. It initializes `FusedRMSNormGated` and applies it to a dummy tensor on a CUDA-enabled GPU. This illustrates how to integrate low-level, optimized operations provided by `fla-core`.

import torch
from fla.modules import FusedRMSNormGated

# fla-core operations require a CUDA-enabled GPU
if not torch.cuda.is_available():
    raise RuntimeError("CUDA not available. fla-core requires a CUDA-enabled GPU.")

device = torch.device("cuda")

# Define model parameters
hidden_size = 768
batch_size = 4
sequence_length = 512

# Initialize FusedRMSNormGated module from fla-core
norm_layer = FusedRMSNormGated(hidden_size).to(device)

# Create dummy input and gate tensors; the gated norm normalizes the
# input and modulates the result with the gate
input_tensor = torch.randn(batch_size, sequence_length, hidden_size, device=device, dtype=torch.float16)
gate_tensor = torch.randn(batch_size, sequence_length, hidden_size, device=device, dtype=torch.float16)

# Perform a forward pass (FusedRMSNormGated expects both the input and a gate)
output_tensor = norm_layer(input_tensor, gate_tensor)

print(f"Input tensor shape: {input_tensor.shape}")
print(f"Output tensor shape: {output_tensor.shape}")
print("FusedRMSNormGated operation successful, demonstrating fla-core usage.")