XFormers

0.0.35 · active · verified Thu Apr 09

XFormers is a PyTorch-based library providing a collection of composable, optimized building blocks for Transformer models. It aims to accelerate deep learning research by offering flexible and highly efficient components, including advanced attention mechanisms and fused operations that often outperform native PyTorch implementations in terms of speed and memory usage. Actively developed by Meta Platforms, Inc., the library frequently releases updates, with the current stable version being 0.0.35.

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to use `xformers.ops.memory_efficient_attention` with dummy PyTorch tensors for both standard and causal attention patterns. It shows the expected tensor layout and the common practice of using half-precision floating point (float16) for performance on GPUs. The `xformers.info` utility is also noted for diagnostics. Ensure PyTorch and CUDA are properly installed and configured.

import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# Ensure tensors are on CUDA if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assume batch_size=2, seq_len=128, num_heads=8, head_dim=64
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64

# Create dummy query, key, value tensors.
# memory_efficient_attention expects (batch_size, seq_len, num_heads, head_dim).
# float16 is recommended for performance, but the fp16 kernels require a CUDA
# device, so fall back to float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)

# Example 1: Standard memory-efficient attention
# xFormers automatically dispatches to the best available operator
output_attn = memory_efficient_attention(query, key, value)
print(f"Output attention shape (standard): {output_attn.shape}")

# Example 2: Causal attention with a lower triangular mask
# LowerTriangularMask restricts each query position to attend only to itself
# and earlier positions. Passing it as attn_bias lets xFormers apply the
# causal constraint without materializing a full (seq_len, seq_len) mask tensor.
attn_bias = LowerTriangularMask()
output_causal_attn = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
print(f"Output attention shape (causal): {output_causal_attn.shape}")

# To verify the installation and list the available kernels, run from a shell:
#   python -m xformers.info
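As a sanity check on what these calls compute, here is a minimal NumPy sketch of scaled dot-product attention with an optional causal mask, using the same (batch, seq_len, num_heads, head_dim) layout as above. This illustrates the math that `memory_efficient_attention` implements, not the optimized xFormers kernel itself.

```python
import numpy as np

def attention(q, k, v, causal=False):
    """Reference scaled dot-product attention.

    q, k, v: arrays of shape (batch, seq_len, num_heads, head_dim),
    matching the layout expected by memory_efficient_attention.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    # Attention scores: (batch, num_heads, seq_len_q, seq_len_k)
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if causal:
        seq = q.shape[1]
        # Mask out keys strictly after each query position
        future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    # Back to (batch, seq_len, num_heads, head_dim)
    return np.einsum("bhqk,bkhd->bqhd", weights, v)

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 128, 8, 64))
k = rng.standard_normal((2, 128, 8, 64))
v = rng.standard_normal((2, 128, 8, 64))
out = attention(q, k, v, causal=True)
print(out.shape)  # (2, 128, 8, 64)
```

With the causal mask, position 0 can only attend to itself, so the first output row equals the first value row per head, which makes a handy correctness check.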
