jaxtyping

0.3.9 · active · verified Thu Apr 09

jaxtyping provides type annotations and optional runtime checking for the shape and data type (dtype) of arrays and tensors from libraries such as JAX, NumPy, and PyTorch. It extends Python's type hints so that array dimensions can be written in the annotation itself. Static type checkers treat these annotations as the plain array type, so shape and dtype mismatches are caught by pairing jaxtyping with a runtime type checker such as beartype or typeguard. The current version is 0.3.9, and the project is actively maintained.

Quickstart

This quickstart shows how to annotate JAX arrays with shape and dtype information using `jaxtyping`. It defines a matrix multiplication function and an array summation function, using the `Float` and `Int` annotations with string-literal shapes. Runtime checking is enabled by decorating each function with `jaxtyped(typechecker=beartype)`, which requires the separate `beartype` package. The annotations alone are not enforced, and static type checkers do not verify shapes.

from beartype import beartype
from jaxtyping import Array, Float, Int, jaxtyped
import jax
import jax.numpy as jnp

# Runtime checks come from pairing jaxtyping with a runtime type checker
# (beartype here; typeguard also works). Undecorated functions are not checked.
@jaxtyped(typechecker=beartype)
def matrix_multiply(
    A: Float[Array, 'rows cols'],
    B: Float[Array, 'cols other_cols']
) -> Float[Array, 'rows other_cols']:
    """Multiplies two matrices, checking shapes at runtime."""
    return jnp.matmul(A, B)

@jaxtyped(typechecker=beartype)
def sum_array(
    x: Int[Array, '...']
) -> Int[Array, '']:
    """Sums an integer array of any shape into a scalar (zero-dimensional) array."""
    return jnp.sum(x)

# --- Example Usage ---
key = jax.random.PRNGKey(0)
# Split the key so the two matrices are independent draws
key_a, key_b = jax.random.split(key)

# Valid multiplication
matrix_A = jax.random.normal(key_a, (3, 4))
matrix_B = jax.random.normal(key_b, (4, 5))
result = matrix_multiply(matrix_A, matrix_B)
print(f"Valid matrix multiplication result shape: {result.shape}")

# Invalid multiplication (runtime error if checks are enabled)
try:
    matrix_C = jax.random.normal(key, (3, 5))
    _ = matrix_multiply(matrix_A, matrix_C)
except Exception as e:
    print(f"Caught expected error for invalid shapes: {e.__class__.__name__}: {e}")

# Integer array sum
int_array = jnp.array([1, 2, 3], dtype=jnp.int32)
int_sum = sum_array(int_array)
print(f"Integer array sum: {int_sum}")
