torchtyping
torchtyping provides runtime type annotations for PyTorch tensors, letting developers specify, and check at runtime, the shape, dtype, names, and layout of tensors. It aims to improve code clarity and reduce bugs by enforcing consistent tensor properties. The library is currently at version 0.1.5.
Common errors
- TypeError: Dimension 'batch' of inconsistent size. Got both X and Y.
  cause: Raised at runtime when `patch_typeguard()` is enabled and a function annotated with `TensorType` receives tensors whose sizes differ for a shared named dimension. [2]
  fix: Ensure that every tensor passed to the annotated function has the same size for each dimension sharing a name (e.g., 'batch'); adjust the input tensor shapes or the function logic.
- Unknown type constructor TensorType
  cause: Raised when compiling a `TensorType`-annotated function with TorchScript (e.g., `@torch.jit.script`); TorchScript does not recognise `TensorType` as a valid type constructor. [15]
  fix: Remove `TensorType` annotations from functions that must be compiled with TorchScript; use plain `torch.Tensor` annotations there, or perform runtime checks outside the scripted code.
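As a sketch of the workaround, a plain `torch.Tensor` annotation compiles cleanly under TorchScript (shape consistency is then no longer checked):

```python
import torch

# TorchScript accepts plain torch.Tensor annotations; a TensorType
# annotation here would fail with "Unknown type constructor TensorType".
@torch.jit.script
def add_tensors_scripted(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

print(add_tensors_scripted(torch.ones(3), torch.ones(3)))  # tensor([2., 2., 2.])
```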
- Static type checkers (e.g., mypy, Pyright) do not flag incorrect tensor shapes in `TensorType` annotations.
  cause: Unlike its successor `jaxtyping`, `torchtyping` is not designed to be compatible with static type checkers; it only provides runtime checks when `typeguard` is enabled. [11, 12]
  fix: This is expected behaviour for `torchtyping`. If static type checking of tensor shapes is desired, migrate to the `jaxtyping` library.
Warnings
- breaking The author strongly recommends migrating to `jaxtyping` for new projects: it supports PyTorch, is compatible with static type checkers, and is the more polished, easier-to-use successor to `torchtyping`. [2, 12]
- gotcha If using `typeguard` for runtime checking, pin it to `typeguard>=2.11.1,<3`; `typeguard` 3.0.0 and above is not compatible with `torchtyping`. [2, 5]
- gotcha `TensorType` annotations are not compatible with static type checkers (e.g., mypy, Pyright): static analysis will not detect incorrect usage of them, so they offer no compile-time error detection. [11, 12]
- gotcha Functions and modules annotated with `TensorType` cannot be compiled with TorchScript; attempting to do so raises 'Unknown type constructor TensorType'. [15]
Install
- pip install torchtyping
Imports
- TensorType
from torchtyping import TensorType
- patch_typeguard
from torchtyping import patch_typeguard
Quickstart
import torch
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

# Call patch_typeguard() once at a global level to enable runtime checking
patch_typeguard()

@typechecked
def add_tensors(x: TensorType['batch'], y: TensorType['batch']) -> TensorType['batch']:
    return x + y

# This works: the 'batch' dimension is 3 in both tensors
result_ok = add_tensors(torch.rand(3), torch.rand(3))
print(f"Operation successful: {result_ok.shape}")

# This raises a TypeError: 'batch' is 3 in one tensor and 1 in the other
try:
    add_tensors(torch.rand(3), torch.rand(1))
except TypeError as e:
    print(f"Caught expected error: {e}")