torchtyping

0.1.5 · maintenance · verified Thu Apr 16

torchtyping provides runtime-checkable type annotations for PyTorch tensors, letting developers specify and dynamically verify the shape, dtype, names, and layout of tensors. It aims to improve code clarity and reduce bugs by enforcing consistent tensor properties.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to define a function with `TensorType` annotations for tensor shapes. When `typeguard` is installed and `patch_typeguard()` is called, these annotations are checked at runtime. The example shows both a successful operation and an expected runtime `TypeError` when tensor shapes are inconsistent.

import torch
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

# Call patch_typeguard() once at a global level to enable runtime checking
patch_typeguard()

@typechecked
def add_tensors(x: TensorType['batch'], y: TensorType['batch']) -> TensorType['batch']:
    return x + y

# Works: both tensors share the same 'batch' size
result_ok = add_tensors(torch.rand(3), torch.rand(3))
print(f"Operation successful: {result_ok.shape}")

# Raises TypeError: the 'batch' dimensions (3 and 1) are inconsistent
try:
    add_tensors(torch.rand(3), torch.rand(1))
except TypeError as e:
    print(f"Caught expected error: {e}")
