jaxtyping
jaxtyping provides type annotations and optional runtime checking for the shape and data type (dtype) of arrays and tensors from numerical libraries such as JAX, NumPy, and PyTorch. It extends Python's type hints so that array dimensions can be written in the annotation itself, for example `Float[Array, 'rows cols']`. Static type checkers treat such an annotation as the plain array type, so shape and dtype mismatches are caught by runtime checking rather than by static analysis. The current version is 0.3.9, and the project is actively developed with frequent releases.
Warnings
- gotcha jaxtyping annotations do nothing at runtime on their own, and static type checkers do not verify shapes or dtypes either. To get shape and dtype checking, pair the annotations with a runtime type checker such as beartype or typeguard, either per function with the `jaxtyping.jaxtyped(typechecker=...)` decorator or per package with `jaxtyping.install_import_hook`. Without this, a shape mismatch passes silently until some downstream operation fails.
- breaking The `DType` type (used for annotating the data type of an array) was removed in version 0.3.0. This was done to simplify the API and resolve conflicts with PEP 646. Code using `DType` will no longer work.
- deprecated Do not rely on global toggle functions such as `jaxtyping.set_active` or `jaxtyping.set_array_typecheck_enabled`; they are not part of jaxtyping's documented API. Runtime checking is switched on through `jaxtyped` or the import hook, and it can be switched off globally by setting the `JAXTYPING_DISABLE=1` environment variable.
Install
-
pip install jaxtyping
Imports
- Array
from jaxtyping import Array
- Float
from jaxtyping import Float
- Int
from jaxtyping import Int
- jaxtyped
from jaxtyping import jaxtyped
Quickstart
# Runtime checking also needs a type checker: pip install jaxtyping jax beartype
from beartype import beartype
from jaxtyping import Array, Float, Int, jaxtyped
import jax
import jax.numpy as jnp
# jaxtyped plus a runtime type checker (beartype here) enables shape/dtype checks
@jaxtyped(typechecker=beartype)
def matrix_multiply(
    A: Float[Array, 'rows cols'],
    B: Float[Array, 'cols other_cols']
) -> Float[Array, 'rows other_cols']:
    """Multiplies two matrices, checking shapes at runtime."""
    return jnp.matmul(A, B)
@jaxtyped(typechecker=beartype)
def sum_array(
    x: Int[Array, '...']
) -> Int[Array, '']:
    """Sums an array of integers, returning a zero-dimensional array."""
    return jnp.sum(x)
# --- Example Usage ---
# Split the key so each matrix is drawn from independent random numbers
key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
# Valid multiplication: (3, 4) @ (4, 5)
matrix_A = jax.random.normal(key_a, (3, 4))
matrix_B = jax.random.normal(key_b, (4, 5))
result = matrix_multiply(matrix_A, matrix_B)
print(f"Valid matrix multiplication result shape: {result.shape}")
# Invalid multiplication: 'cols' is bound to 4 by A, but C has 3 rows
try:
    matrix_C = jax.random.normal(key_c, (3, 5))
    _ = matrix_multiply(matrix_A, matrix_C)
except Exception as e:
    print(f"Caught expected error for invalid shapes: {e.__class__.__name__}: {e}")
# Integer array sum
int_array = jnp.array([1, 2, 3], dtype=jnp.int32)
int_sum = sum_array(int_array)
print(f"Integer array sum: {int_sum}")