Rotary Embedding for PyTorch
This library provides a PyTorch implementation of Rotary Positional Embedding (RoPE), a positional-encoding scheme used by modern transformer architectures such as LLaMA to improve a model's ability to handle long sequences. It offers a simple API for applying rotary embeddings to query and key tensors. The current version is 0.8.9, and the project follows a rapid release cadence for bug fixes and minor improvements.
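To build intuition for what the library computes, here is a minimal from-scratch sketch of RoPE (independent of this library, using the half-split pairing convention): each pair of feature channels is rotated by an angle proportional to the token's position, so the dot product between a rotated query and key depends only on their relative offset.

```python
import torch

def rotate_half(x):
    # Pair channel i with channel i + dim/2 and rotate each pair by 90 degrees:
    # (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, theta=10000.0):
    # x: (..., seq_len, dim) with dim even
    *_, seq_len, dim = x.shape
    # Per-pair frequencies, as in the original RoFormer formulation
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(seq_len).float()
    angles = torch.outer(pos, inv_freq)           # (seq_len, dim/2)
    angles = torch.cat((angles, angles), dim=-1)  # (seq_len, dim)
    return x * angles.cos() + rotate_half(x) * angles.sin()
```

Because each position is multiplied by a pure rotation, vector norms are preserved, and shifting both query and key positions by the same amount leaves their dot product unchanged.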
Warnings
- gotcha Prior to v0.8.0, a bug in the chi scale multiplication could produce incorrect positional embeddings, particularly with certain scaling factors. Models trained with affected versions may show subtle performance degradation or instability.
- gotcha Versions prior to v0.8.6 may misbehave under `torch.compile` because `seq_len` was cached as a non-integer type, which could cause compilation failures or incorrect behavior under JIT.
- gotcha The `RotaryEmbedding` module caches internal tables based on the maximum sequence length it has encountered. If input sequence lengths vary widely (e.g., during inference), or the maximum length is not pre-configured generously enough, the module may recompute tables repeatedly, or raise out-of-bounds errors when an input's `seq_len` exceeds the cached maximum.
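The caching concern in the last gotcha can be illustrated with a standalone sketch (pure PyTorch, not this library's internals): precompute the cos/sin tables once up to a chosen maximum length, slice them for shorter inputs, and rebuild only when a request exceeds the cached length.

```python
import torch

class RopeCache:
    """Illustrative cos/sin cache for rotary embeddings (not the library's code)."""

    def __init__(self, dim, max_seq_len=2048, theta=10000.0):
        self.dim = dim
        self.theta = theta
        self.max_seq_len = 0
        self._build(max_seq_len)

    def _build(self, seq_len):
        # (Re)compute tables covering seq_len positions
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim))
        angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
        angles = torch.cat((angles, angles), dim=-1)
        self.cos, self.sin = angles.cos(), angles.sin()
        self.max_seq_len = seq_len

    def get(self, seq_len):
        # Grow the cache only when the request exceeds it -- this rebuild
        # is the re-computation the gotcha above warns about
        if seq_len > self.max_seq_len:
            self._build(seq_len)
        return self.cos[:seq_len], self.sin[:seq_len]
```

Sizing the initial `max_seq_len` to the longest sequence you expect avoids rebuilds entirely; the trade-off is a larger one-time table in memory.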
Install
pip install rotary-embedding-torch
Imports
- RotaryEmbedding
from rotary_embedding_torch import RotaryEmbedding
Quickstart
import torch
from rotary_embedding_torch import RotaryEmbedding
# Per-head embedding dimension to rotate
dim = 64
# Initialize RotaryEmbedding; frequency tables are computed and cached
# internally as sequence lengths are encountered
rotary_emb = RotaryEmbedding(dim = dim)
# Create dummy query and key tensors
# shape: (batch_size, num_heads, sequence_length, head_dim)
seq_len = 1024
q = torch.randn(1, 8, seq_len, dim)
k = torch.randn(1, 8, seq_len, dim)
# Rotate queries and keys (apply before the attention dot product)
q_rot = rotary_emb.rotate_queries_or_keys(q)
k_rot = rotary_emb.rotate_queries_or_keys(k)
print(f"Original query shape: {q.shape}")
print(f"Rotary-embedded query shape: {q_rot.shape}")
print(f"Example embedded value (first element): {q_rot[0, 0, 0, 0].item():.4f}")
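Downstream, the rotated queries and keys plug directly into ordinary attention (values are never rotated). A self-contained sketch using only PyTorch primitives; the `rope` helper below is a hypothetical stand-in for the library's rotation, included so the example runs without the package installed:

```python
import torch
import torch.nn.functional as F

def rope(x, theta=10000.0):
    # Stand-in rotation (half-split convention); with this library installed
    # you would call rotary_emb.rotate_queries_or_keys instead
    *_, seq_len, dim = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)
    angles = torch.cat((angles, angles), dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * angles.cos() + rotated * angles.sin()

# Rotate q and k (but not v), then run scaled dot-product attention
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
out = F.scaled_dot_product_attention(rope(q), rope(k), v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```

Because RoPE acts on q and k before the dot product, it composes with any attention backend, including fused kernels selected by `scaled_dot_product_attention`.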