{"id":2078,"library":"jaxtyping","title":"jaxtyping","description":"jaxtyping provides type annotations, and optional runtime checking, for the shape and data type (dtype) of array-like objects from numerical libraries such as JAX, NumPy, and PyTorch. It extends Python's type hints to express array dimensions. Static type checkers see only the underlying array type, while a runtime type checker such as beartype can enforce shapes and dtypes, helping to catch shape-related errors early. The current version is 0.3.9, and the project is under active development with frequent releases.","status":"active","version":"0.3.9","language":"en","source_language":"en","source_url":"https://github.com/google/jaxtyping","tags":["jax","numpy","pytorch","type-checking","annotations","array-shapes"],"install":[{"cmd":"pip install jaxtyping","lang":"bash","label":"Install core library"}],"dependencies":[{"reason":"Required to use `jaxtyping.Array` (an alias of `jax.Array`) when annotating JAX arrays.","package":"jax","optional":true},{"reason":"Required for NumPy array type checking.","package":"numpy","optional":true},{"reason":"Required for PyTorch tensor type checking.","package":"torch","optional":true},{"reason":"Runtime type checker used together with `jaxtyped` to enforce annotations at runtime.","package":"beartype","optional":true}],"imports":[{"note":"Alias of `jax.Array`, used as the array type for JAX arrays. For NumPy or PyTorch, use `np.ndarray` or `torch.Tensor` as the array type instead (e.g. `Float[np.ndarray, 'n']`).","symbol":"Array","correct":"from jaxtyping import Array"},{"note":"Dtype wrapper for floating-point arrays, written as `Float[ArrayType, 'shape']`.","symbol":"Float","correct":"from jaxtyping import Float"},{"note":"Dtype wrapper for integer arrays, written as `Int[ArrayType, 'shape']`.","symbol":"Int","correct":"from jaxtyping import Int"},{"note":"Decorator that enables runtime shape and dtype checking when combined with a runtime type checker, e.g. `@jaxtyped(typechecker=beartype)`.","symbol":"jaxtyped","correct":"from jaxtyping import jaxtyped"}],"quickstart":{"code":"from beartype import beartype\nfrom jaxtyping import Array, Float, Int, jaxtyped\nimport jax\nimport jax.numpy as jnp\n\n# Runtime checking is opt-in: combine jaxtyped with a runtime type checker (beartype here).\n@jaxtyped(typechecker=beartype)\ndef matrix_multiply(\n    A: Float[Array, 'rows cols'],\n    B: Float[Array, 'cols other_cols']\n) -> Float[Array, 'rows other_cols']:\n    \"\"\"Multiplies two matrices, checking shapes at runtime.\"\"\"\n    return jnp.matmul(A, B)\n\n@jaxtyped(typechecker=beartype)\ndef sum_array(\n    x: Int[Array, '...']\n) -> Int[Array, '']:\n    \"\"\"Sums an integer array of any shape to a zero-dimensional array.\"\"\"\n    return jnp.sum(x)\n\n# --- Example Usage ---\nkey_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)\n\n# Valid multiplication: (3, 4) @ (4, 5) -> (3, 5)\nmatrix_A = jax.random.normal(key_a, (3, 4))\nmatrix_B = jax.random.normal(key_b, (4, 5))\nresult = matrix_multiply(matrix_A, matrix_B)\nprint(f\"Valid matrix multiplication result shape: {result.shape}\")\n\n# Invalid multiplication: A binds 'cols' to 4, but this second argument has 3 rows\ntry:\n    matrix_C = jax.random.normal(key_c, (3, 5))\n    _ = matrix_multiply(matrix_A, matrix_C)\nexcept Exception as e:\n    print(f\"Caught expected error for invalid shapes: {e.__class__.__name__}: {e}\")\n\n# Integer array sum (the result has shape ())\nint_array = jnp.array([1, 2, 3], dtype=jnp.int32)\nint_sum = sum_array(int_array)\nprint(f\"Integer array sum: {int_sum}\")","lang":"python","description":"This quickstart shows how to annotate JAX arrays with shape and dtype information using `Float` and `Int` with string shape specifications. Named dimensions such as `cols` must agree across all arguments and the return value within a single call, `'...'` matches any number of dimensions, and `''` denotes a scalar (zero-dimensional) array. Runtime enforcement is enabled per function with `@jaxtyped(typechecker=beartype)`, so the mismatched call raises an error while the arguments are checked, before `jnp.matmul` runs. Running it requires `jax` and `beartype` to be installed alongside `jaxtyping`."},"warnings":[{"fix":"Install a runtime type checker (e.g. `pip install beartype`) and decorate functions with `@jaxtyped(typechecker=beartype)`, or import your package inside `with jaxtyping.install_import_hook('your_package', 'beartype.beartype'):` to instrument it as a whole.","message":"jaxtyping annotations are not checked at runtime by default. Shape and dtype checks run only for functions decorated with `@jaxtyped(typechecker=...)` or instrumented via `install_import_hook`. Static type checkers such as mypy or pyright treat `Float[Array, 'rows cols']` as a plain `Array`, so without runtime checking, shape and dtype mismatches go undetected.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Write `Float[Array, '...']` (dtype wrapper, then array type, then shape string) rather than `Array[float, '...']` or `Array[DType[float], '...']`. Use `Shaped[Array, '...']` to constrain only the shape and accept any dtype.","message":"Dtypes are expressed through wrapper types such as `Float`, `Int`, `Bool`, and `Shaped`, not through a separate `DType` annotation or by subscripting `Array` with a dtype.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Remove calls such as `set_array_typecheck_enabled(...)` or `set_active(...)`. Enable checking per function with `@jaxtyped(typechecker=beartype)` or per package with `jaxtyping.install_import_hook`.","message":"jaxtyping has no global on/off function such as `set_array_typecheck_enabled` or `set_active`. Runtime checking is opt-in and is attached to individual functions or imported modules, so calls to such functions fail with an `ImportError` or `AttributeError`.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-09T00:00:00.000Z","next_check":"2026-07-08T00:00:00.000Z"}