x-transformers
x-transformers is a concise yet fully-featured PyTorch library for attention-based transformer architectures, collecting promising experimental features from recent research papers. Maintained by lucidrains, it focuses on integrating cutting-edge advances quickly. The library is currently at version 2.17.9 and receives frequent updates, reflecting its experimental, research-oriented nature.
Warnings
- gotcha The library is in 'Beta' development status (Development Status :: 4 - Beta) and frequently integrates experimental features from recent research papers. API stability and feature behavior may therefore change rapidly between versions, and some features may be less thoroughly tested than in more mature libraries.
- gotcha When configuring embedding normalization, it's recommended to use either `l2norm_embed` or `post_emb_norm`, but not both simultaneously, as they are designed to serve similar purposes and using both might lead to redundant or conflicting behavior.
- gotcha Some advanced or experimental features, such as 'Rezero Is All You Need' (as noted in an older GitHub issue), might exhibit stability or convergence issues (e.g., producing NaN values) depending on the specific use case, dataset, and hyperparameter tuning.
Install
pip install x-transformers
Imports
- TransformerWrapper
from x_transformers import TransformerWrapper
- Decoder
from x_transformers import Decoder
- Encoder
from x_transformers import Encoder
- XTransformer
from x_transformers import XTransformer
Quickstart
import torch
from x_transformers import TransformerWrapper, Decoder

# Example for a decoder-only (GPT-like) model
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

# Token ids sampled over the full vocabulary
x = torch.randint(0, 20000, (1, 1024))

# Move model and input to GPU if available; skip .cuda() for CPU-only runs
if torch.cuda.is_available():
    model = model.cuda()
    x = x.cuda()

output = model(x)  # logits of shape (batch, seq_len, num_tokens)
print(f"Output shape: {output.shape}")  # torch.Size([1, 1024, 20000])