Local Attention

1.11.2 · active · verified Wed Apr 15

local-attention is a Python library by lucidrains that implements local (windowed) attention with a configurable window size and look-backward/look-forward ranges, primarily for language modeling tasks. It is built on PyTorch and is actively maintained with frequent minor and patch releases.

Warnings

Install

Imports

Quickstart

Initializes a `LocalAttention` module with specified windowing parameters and applies it to dummy query, key, and value tensors. This example demonstrates a causal attention setup suitable for autoregressive models. Note that `LocalAttention` takes queries, keys, and values as separate arguments; it is not called on a single input tensor.

import torch
from local_attention import LocalAttention

# Ensure reproducibility
torch.manual_seed(42)

# Define the local attention layer.
# window_size defines the local neighborhood.
# look_backward=1 lets each window also attend to the window before it.
# look_forward=0 means no window attends to the window after it (causal).
attn = LocalAttention(
    window_size = 512,
    look_backward = 1,
    look_forward = 0,
    dropout = 0.,
    causal = True, # Set to True for autoregressive models
    exact_windowsize = False
)

# Dummy query, key, and value tensors:
# (batch, heads, sequence_length, head_dimension).
# The sequence length (1024) is a multiple of window_size (512).
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# Apply local attention
out = attn(q, k, v)

print(f"Query shape: {q.shape}")
print(f"Output shape: {out.shape}")
# The output shape matches the query shape
assert out.shape == q.shape
print("Local attention applied successfully.")
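To see what windowing with `look_backward`/`look_forward` means, here is an illustrative sketch in pure Python (not the library's actual implementation; `local_attention_pattern` is a hypothetical helper) that enumerates which key positions each query may attend to:

```python
def local_attention_pattern(seq_len, window_size, look_backward=1,
                            look_forward=0, causal=True):
    """For each query position, list the key positions it may attend to.

    The sequence is split into contiguous windows of window_size tokens.
    A query in window w may attend to keys in windows
    [w - look_backward, w + look_forward]; with causal=True, keys to the
    query's right are additionally masked out.
    """
    allowed = []
    for q in range(seq_len):
        w = q // window_size  # window index of the query
        lo = max(0, (w - look_backward) * window_size)
        hi = min(seq_len, (w + 1 + look_forward) * window_size)
        keys = [k for k in range(lo, hi) if not causal or k <= q]
        allowed.append(keys)
    return allowed

# Toy example: 8 tokens, window of 4, causal, one window of lookback.
pat = local_attention_pattern(8, 4)
print(pat[5])  # token 5 (window 1) sees windows 0-1 up to itself: [0, 1, 2, 3, 4, 5]
```

With `look_backward=1` a token near a window boundary still sees its recent left context in the previous window, which is why the Quickstart's causal configuration does not isolate each 512-token window.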
