Torch-Struct


A library for structured prediction (e.g., parsing, sequence labeling) with PyTorch. Provides differentiable implementations of dynamic programming algorithms like CYK, Inside-Outside, and Viterbi. Current version 0.5, last updated 2022. Low release cadence.

pip install torch-struct
error AttributeError: module 'torch_struct' has no attribute 'Semiring'
cause `Semiring` is no longer exposed in the top-level namespace as of v0.5, or the import path is incorrect.
fix
Use the `StructDistribution` API instead, e.g. `from torch_struct import LinearChainCRF`.
error RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 2
cause The potentials tensor shape does not match what the chosen struct expects (e.g., `LinearChainCRF` expects (batch, N-1, C, C), not (batch, N, C)).
fix
Check the expected input shape for the struct class. For `LinearChainCRF`, use (batch, N-1, C, C), where N is the sequence length and C the number of labels. For `TreeCRF`, use (batch, N, N, NT), where NT is the number of nonterminals.
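As a quick sanity check on shapes (pure tensor construction; the shape conventions match the library's `LinearChainCRF` and `TreeCRF` distributions):

```python
import torch

batch, N, C, NT = 2, 5, 3, 4
# LinearChainCRF scores every adjacent label pair: one (C, C) table per transition,
# so there are N-1 transitions for a length-N sequence.
chain_potentials = torch.randn(batch, N - 1, C, C)
# TreeCRF scores every span (i, j) for each nonterminal.
tree_potentials = torch.randn(batch, N, N, NT)
print(chain_potentials.shape, tree_potentials.shape)
```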
error ValueError: Expected more than 1 value per channel when training, got input size ...
cause A batch-normalization layer in the model is receiving only one sample per channel; not a torch-struct issue, but common when integrating.
fix
Ensure batch size > 1 or set model.eval() during inference.
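A minimal reproduction of the BatchNorm behavior (pure PyTorch, independent of torch-struct):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(1, 4)  # batch of one sample

# Training mode needs more than one value per channel to compute batch statistics.
bn.train()
try:
    bn(x)
except ValueError as e:
    print("train mode failed:", e)

# Eval mode uses running statistics, so a single sample is fine.
bn.eval()
print(bn(x).shape)  # torch.Size([1, 4])
```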
breaking In v0.5, the `genbmm` dependency was dropped and some lesser-used features were removed (e.g., NeuralPottsModel). If you upgrade from v0.4, code relying on those features will break.
fix Remove `genbmm` imports; use torch.bmm or einsum instead. Replace removed models with custom implementations.
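A sketch of replacing genbmm's log-space batched matmul with plain torch ops (assuming the common log-space use case; `torch.logsumexp` performs the contraction directly):

```python
import torch

# Log-space batched matmul: out[n, i, j] = logsumexp_k(a[n, i, k] + b[n, k, j])
a = torch.randn(2, 3, 4)  # (batch, i, k), log-space
b = torch.randn(2, 4, 5)  # (batch, k, j), log-space
log_bmm = torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(1), dim=2)

# Same result via ordinary-space bmm (less numerically stable for large magnitudes):
reference = torch.log(torch.bmm(a.exp(), b.exp()))
print(torch.allclose(log_bmm, reference, atol=1e-5))  # True
```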
deprecated The `Semiring` API is deprecated in favor of `StructDistribution`. Old code using `Semiring` will still work but generate deprecation warnings.
fix Replace `Semiring` usage with `StructDistribution` and log-potential pattern.
gotcha Potentials must be in log-space (logits). The library does not check for this; using raw probabilities leads to incorrect results.
fix Ensure inputs are log-space (e.g., use torch.log(probs) if you have probabilities).
gotcha Many struct classes require the length of each sequence (mask) for variable-length sequences. Without a mask, the algorithms assume fully-observed sequences of length N, which can produce incorrect gradients.
fix Pass a `lengths` tensor of shape (batch,) to the constructor: `dist = LineCRF(log_potentials, lengths=lengths)`.

Creates a linear-chain CRF from random potentials and computes the log-partition function.

import torch
import torch_struct

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch, N, C = 2, 5, 3
# LinearChainCRF expects pairwise log-potentials of shape (batch, N-1, C, C)
log_potentials = torch.randn(batch, N - 1, C, C, device=device)
dist = torch_struct.LinearChainCRF(log_potentials)
log_partition = dist.partition  # tensor of shape (batch,)
print(f"Log partition: {log_partition}")