CoLT5 Attention

CoLT5 Attention (Conditionally Routed Attention) is an implementation of the CoLT5 architecture for efficient long-context transformers. Current version: 0.11.1; the project has a rapid release cadence.

pip install colt5-attention
error AttributeError: module 'colt5_attention' has no attribute 'CoLT5Attention'
cause Wrong import path; user tried `import colt5_attention` then `colt5_attention.CoLT5Attention` but the class may not be top-level in older versions.
fix Use `from colt5_attention import CoLT5Attention`.
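A minimal sketch of the failing and working import patterns, assuming the `CoLT5Attention` name used in this entry (top-level availability may vary by version):

import colt5_attention
# attn = colt5_attention.CoLT5Attention(...)  # AttributeError on versions where the class is not exposed at the top level

from colt5_attention import CoLT5Attention  # import the class directly instead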
error RuntimeError: Expected all tensors to be on the same device, but found at least two devices
cause Passing tensors on different devices (e.g., CPU and CUDA) to the attention module.
fix Ensure all inputs are on the same device, e.g. `x = x.to('cuda')`, `mask = mask.to('cuda')`.
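A sketch of keeping the module and its inputs on one device; the constructor arguments mirror the example at the end of this section, and the exact signature is an assumption:

import torch
from colt5_attention import CoLT5Attention

device = 'cuda' if torch.cuda.is_available() else 'cpu'

attn = CoLT5Attention(
    dim=512,
    num_routed_queries=64,
    num_routed_key_values=64,
    num_heads=8,
).to(device)                                  # move module parameters to the target device

x = torch.randn(2, 1024, 512, device=device)  # create the input directly on the same device
out = attn(x)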
breaking Version 0.11.0 changed the default value of `num_routed_queries` from 128 to 64. Existing code that relies on the default may see different memory use and performance.
fix Explicitly set `num_routed_queries=128` if you need the old behavior.
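To reproduce the pre-0.11.0 behavior after upgrading, pin the value explicitly rather than relying on the default (argument names follow this section's example):

attn = CoLT5Attention(
    dim=512,
    num_routed_queries=128,     # restore the pre-0.11.0 default explicitly
    num_routed_key_values=64,
    num_heads=8,
)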
gotcha Attention masking must use a float mask (e.g., 0.0 for keep, -inf for mask). Boolean masks are not supported and may silently produce wrong outputs.
fix Convert the boolean mask to a float additive mask, e.g. `attn_mask = (~keep_mask).float() * -1e9` (0.0 where kept, -1e9 where masked; the large finite value stands in for -inf).
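A minimal sketch of building the float mask from a boolean keep-mask; how the mask is passed into the attention call depends on the module's signature and is not shown here:

import torch

keep_mask = torch.ones(2, 1024, dtype=torch.bool)  # True = attend to this position
keep_mask[:, 768:] = False                          # e.g. mask out trailing padding positions
attn_mask = (~keep_mask).float() * -1e9             # 0.0 where kept, -1e9 where masked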
gotcha The input tensor must be contiguous. Non-contiguous tensors can cause runtime errors or incorrect gradient propagation.
fix Call `x = x.contiguous()` before passing to attention.
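A quick guard for the common case where a transpose or slice has produced a non-contiguous view:

import torch

x = torch.randn(2, 512, 1024).transpose(1, 2)  # transpose yields a non-contiguous view of shape (2, 1024, 512)
if not x.is_contiguous():
    x = x.contiguous()                          # copy into contiguous memory before the attention call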

Initialize CoLT5 attention with routed queries/key-values and run a forward pass.

import torch
from colt5_attention import CoLT5Attention

attn = CoLT5Attention(
    dim=512,                    # model / embedding dimension
    num_routed_queries=64,      # queries selected for the heavy (routed) attention branch
    num_routed_key_values=64,   # key/value positions selected for the heavy branch
    num_heads=8,
    dropout=0.1
)

x = torch.randn(2, 1024, 512)  # (batch, sequence length, dim)
out = attn(x)
print(out.shape)  # (2, 1024, 512), same shape as the input