torchscale

version 0.3.0 · verified Sat May 09

torchscale is a PyTorch library for building large-scale Transformer models, providing scalable building blocks such as multi-head attention (MHA) and stacked decoder components. As of version 0.3.0 it supports Python >=3.8 and is maintained by Microsoft. Releases are infrequent.

pip install torchscale
error ModuleNotFoundError: No module named 'torchscale'
cause torchscale is not installed or installed in a different environment.
fix
Run pip install torchscale from the correct Python environment.
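A quick way to diagnose the "different environment" case is to check which interpreter is actually running your code, then install into exactly that interpreter (the path below is a placeholder for whatever `sys.executable` prints):

```python
import sys

# The interpreter executing your code. `pip` must belong to this same
# installation, or the import will still fail after a "successful" install.
print(sys.executable)

# From a shell, install into exactly this interpreter:
#   /path/to/python -m pip install torchscale
```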
error AttributeError: module 'torchscale' has no attribute 'MHA'
cause Importing from wrong path: top-level module does not contain MHA.
fix
Use from torchscale.component import MHA.
error RuntimeError: The expanded size of the tensor must match the existing size at non-singleton dimension
cause Input tensor dimensions do not match MHA expectations, often due to sequence-first vs batch-first confusion.
fix
Ensure input shape is (batch, seq_len, embed_dim). Use x = x.transpose(0,1) if using (seq_len, batch, embed_dim).
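The shape fix above can be sketched with plain torch tensors, no torchscale required: swap the first two dimensions to convert a sequence-first tensor into the batch-first layout.

```python
import torch

# Sequence-first layout: (seq_len=10, batch=4, embed_dim=512)
x_seq_first = torch.randn(10, 4, 512)

# transpose(0, 1) swaps the first two dimensions, giving the
# batch-first layout (batch, seq_len, embed_dim) that MHA expects.
x_batch_first = x_seq_first.transpose(0, 1)

print(x_batch_first.shape)  # torch.Size([4, 10, 512])
```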
deprecated The `torchscale.model.LongShortTerm` class is deprecated in 0.3.0; use `IncrementalDecoder` or `TemporalDecoder` instead.
fix Replace `LongShortTerm` with `IncrementalDecoder` or `TemporalDecoder` depending on use case.
breaking In version 0.3.0, the `MHA` class no longer accepts `kdim` and `vdim` arguments; use `embed_dim` for all.
fix Remove `kdim` and `vdim` from MHA constructor and ensure all dimensions match `embed_dim`.
gotcha torchscale components expect batch-first tensors (batch, seq, dim), not sequence-first. Incorrect ordering may cause shape mismatches.
fix Ensure input tensors have shape (batch, sequence, features) or use .transpose() if needed.

Initialize MHA and IncrementalDecoder with random input.

import torch
from torchscale.component import MHA
from torchscale.model import IncrementalDecoder

# Example: multi-head attention (batch-first input, see gotcha above)
mha = MHA(embed_dim=512, num_heads=8)
x = torch.randn(4, 10, 512)  # (batch=4, seq_len=10, embed_dim=512)
output = mha(x, x, x)        # self-attention: query, key, value all x
print(output.shape)

# Example: decoder over integer token ids
decoder = IncrementalDecoder(
    vocab_size=1000,
    embed_dim=512,
    num_heads=8,
    num_layers=6,
)
tokens = torch.randint(0, 1000, (4, 20))  # (batch=4, seq_len=20)
logits = decoder(tokens)
print(logits.shape)