Rotary Embedding for PyTorch


This library provides a PyTorch implementation of Rotary Positional Embedding (RoPE), a core component of modern transformer architectures such as LLaMA that improves a model's handling of long sequences. It offers a simple API for applying rotary embeddings to query and key tensors. The current version is 0.8.9, with a rapid release cadence for bug fixes and minor improvements.
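For readers unfamiliar with the mechanism: RoPE encodes position by rotating consecutive pairs of feature dimensions through position-dependent angles, rather than adding a position vector. The sketch below is a minimal reference implementation independent of this library; the function name `rope_reference` and its signature are illustrative, not part of the package's API.

```python
import torch

def rope_reference(x, theta=10000.0):
    """Illustrative RoPE: rotate consecutive dimension pairs of
    x (seq_len, dim) by angles that grow with position."""
    seq_len, dim = x.shape
    # One frequency per dimension pair, geometrically spaced
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # angles[m, i] = m * freqs[i], shape (seq_len, dim // 2)
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    # 2-D rotation applied to each (x1, x2) pair
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = torch.randn(16, 8)
y = rope_reference(x)
```

Because each pair is rotated (an orthogonal transform), the per-position norm of `y` equals that of `x`, and position 0 is left unchanged.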


Install

pip install rotary-embedding-torch

Imports

from rotary_embedding_torch import RotaryEmbedding

Quickstart

This example demonstrates how to initialize `RotaryEmbedding` and apply it to query and key tensors as used in self-attention. The `dim` argument is the number of head dimensions to rotate; set it equal to `head_dim` to rotate the full head dimension.

import torch
from rotary_embedding_torch import RotaryEmbedding

# Rotary dimension: equal to head_dim here, so the full head dimension is rotated
dim = 64
rotary_emb = RotaryEmbedding(dim=dim)

# Create dummy query and key tensors
# shape: (batch_size, num_heads, sequence_length, head_dim)
seq_len = 1024
q = torch.randn(1, 8, seq_len, dim)
k = torch.randn(1, 8, seq_len, dim)

# Apply rotary embeddings (after the q/k projections, before attention)
q_rot = rotary_emb.rotate_queries_or_keys(q)
k_rot = rotary_emb.rotate_queries_or_keys(k)

print(f"Original query shape: {q.shape}")
print(f"Rotary-embedded query shape: {q_rot.shape}")
print(f"Example embedded value (first element): {q_rot[0, 0, 0, 0].item():.4f}")
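The property that makes RoPE work for attention is that the dot product between a rotated query and a rotated key depends only on their relative offset, not their absolute positions. The self-contained check below demonstrates this with a hand-rolled pairwise rotation (the helper `rotate` is illustrative, not a library function): shifting both positions by the same amount leaves the score unchanged.

```python
import torch

def rotate(x, pos, theta=10000.0):
    # Rotate dimension pairs of a single vector x (dim,) by angles pos * freqs
    dim = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    ang = pos * freqs
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)

# Query at position 3, key at position 7 (offset 4) ...
a = rotate(q, 3.0) @ rotate(k, 7.0)
# ... versus positions 103 and 107 (same offset 4)
b = rotate(q, 103.0) @ rotate(k, 107.0)
print(torch.allclose(a, b, atol=1e-4))
```

This shift invariance is why RoPE generalizes better across sequence positions than learned absolute embeddings.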
