x-transformers

2.17.9 · active · verified Sat Apr 11

x-transformers is a concise yet fully-featured PyTorch library for attention-based transformers, collecting promising experimental features and architectures from recent research papers. Maintained by lucidrains, it focuses on integrating cutting-edge advances and is updated frequently, reflecting its experimental, research-oriented nature.

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to set up a basic decoder-only (GPT-like) transformer model using `TransformerWrapper` and `Decoder`. It initializes a model with a specified vocabulary size, sequence length, and decoder attention layer configuration, then runs a sample forward pass.

import torch
from x_transformers import TransformerWrapper, Decoder

# Example for a decoder-only (GPT-like) model
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

# Move the model and inputs to GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Random token ids in [0, num_tokens), shape (batch, seq_len)
x = torch.randint(0, 20000, (1, 1024), device = device)

output = model(x)
print(f"Output shape: {output.shape}")  # logits of shape (batch, seq_len, num_tokens)
