Local Attention
local-attention is a Python library by lucidrains that implements local (windowed) attention with a configurable `window_size` and `look_backward`/`look_forward` spans, primarily for language modeling tasks. It is built on PyTorch for efficient computation and is actively maintained, with frequent minor and patch releases.
Warnings
- gotcha Misunderstanding the interplay between `window_size`, `look_backward`, `look_forward`, and `causal` can lead to unintended attention patterns or incorrect information flow. For instance, `causal=True` is contradictory with `look_forward > 0`, since a causal model must not attend to future tokens; recent versions reject this combination with an `AssertionError` at construction time.
- gotcha As a PyTorch-based library, ensuring input tensors are on the correct device (CPU/GPU) and have compatible data types (`torch.float32`, `torch.float16`) is crucial. Mismatches frequently cause runtime errors or significantly degraded performance.
- gotcha `LocalAttention.forward` takes separate query, key, and value tensors, each of shape `(batch, sequence_length, head_dimension)` (an extra heads dimension, `(batch, heads, sequence_length, head_dimension)`, is also accepted). Passing a single tensor, or transposing `sequence_length` and `head_dimension`, is a common source of `RuntimeError` or `ValueError`. The sequence length must also be divisible by `window_size` unless `autopad=True` is set.
- gotcha Version 1.11.0 updated the internal `look_around()` function to use native PyTorch functionality. While not a public API breaking change, it represents a significant internal optimization. If you were relying on previous internal behaviors (e.g., via subclassing or monkey-patching), this change could affect your custom logic.
Install
- pip install local-attention
Imports
- LocalAttention
from local_attention import LocalAttention
Quickstart
import torch
from local_attention import LocalAttention

# Ensure reproducibility
torch.manual_seed(42)

# Define the local attention layer.
# window_size defines the local neighborhood of each query.
# look_backward=1 means each window also attends to the 1 window before it.
# look_forward=0 means no window attends to the windows after it (causal).
attn = LocalAttention(
window_size = 512,
look_backward = 1,
look_forward = 0,
dropout = 0.,
causal = True, # Set to True for autoregressive models
exact_windowsize = False
)
# LocalAttention takes separate query, key, and value tensors,
# each of shape (batch, sequence_length, head_dimension).
# For example, a batch of 1 sequence, 1024 tokens long, 64 dims per head.
q = torch.randn(1, 1024, 64)
k = torch.randn(1, 1024, 64)
v = torch.randn(1, 1024, 64)
# Apply local attention
out = attn(q, k, v)
print(f"Query shape: {q.shape}")
print(f"Output shape: {out.shape}")
# The output shape matches the query shape
assert out.shape == q.shape
print("Local attention applied successfully.")