{"id":9804,"library":"haliax","title":"Haliax: Named Tensors for JAX","description":"Haliax (version 1.3) provides named tensors for JAX, improving legibility and reducing common shape-related errors in deep learning models. Users refer to tensor dimensions by name rather than by position, which simplifies operations such as broadcasting, reduction, and concatenation. The library is actively developed with frequent minor releases and occasional major updates.","status":"active","version":"1.3","language":"en","source_language":"en","source_url":"https://github.com/stanford-crfm/haliax","tags":["jax","deep-learning","named-tensors","machine-learning"],"install":[{"cmd":"pip install haliax","lang":"bash","label":"Install Haliax (CPU JAX)"},{"cmd":"pip install haliax \"jax[cuda12]\"","lang":"bash","label":"Install Haliax (CUDA JAX)"}],"dependencies":[{"reason":"Core dependency for tensor computation; Haliax is built on JAX.","package":"jax"},{"reason":"Fundamental array operations and interoperability.","package":"numpy"},{"reason":"Backports newer typing features to older Python versions.","package":"typing_extensions"}],"imports":[{"symbol":"NamedArray","correct":"from haliax import NamedArray"},{"symbol":"Axis","correct":"from haliax import Axis"},{"symbol":"product","correct":"from haliax import product"},{"symbol":"nn","correct":"import haliax.nn as nn"},{"note":"Linear is part of the haliax.nn submodule, not top-level.","wrong":"from haliax import Linear","symbol":"Linear","correct":"from haliax.nn import Linear"}],"quickstart":{"code":"import haliax as hx\nimport jax\nimport jax.random as jr\n\n# 1. Define axes with their names and sizes\nBatch = hx.Axis(\"batch\", 4)\nFeatures = hx.Axis(\"features\", 8)\n\n# 2. 
Create a NamedArray\n# The shape is a tuple of Axis objects, listed in dimension order\nkey = jr.PRNGKey(0)\ndata_array = hx.random.normal(key, (Batch, Features))\n\nprint(f\"NamedArray axes: {data_array.axes}\")\n# take() expects the axis first, then the index\nprint(f\"Value for batch index 0: {data_array.take(Batch, 0).array.round(2)}\")\n\n# 3. Perform an operation, e.g., sum over the Features axis\nsummed_array = data_array.sum(Features)\nprint(f\"Summed array axes: {summed_array.axes}\") # Expected: (Batch,)\nprint(f\"Summed array values: {summed_array.array.round(2)}\")\n\n# 4. Dot product example\n# The second array shares the Features axis, which is the axis to contract.\n# Axis names must be unique within an array, so the output needs a new axis.\nHidden = hx.Axis(\"hidden\", 16)\ndata_array_2 = hx.random.normal(jr.PRNGKey(1), (Features, Hidden))\n\n# Contract over the shared Features axis; axes are matched by name, not position\n# (older releases took the axis as the first positional argument instead)\nproduct_array = hx.dot(data_array, data_array_2, axis=Features)\nprint(f\"Dot product array axes: {product_array.axes}\") # Expected: (Batch, Hidden)\nprint(f\"Dot product values shape: {product_array.array.shape}\")","lang":"python","description":"This quickstart shows how to define named axes, create `NamedArray` instances from them, and perform basic operations such as reduction and dot products, where Haliax aligns dimensions by name."},"warnings":[{"fix":"Consult the Haliax 1.0+ documentation for updated API calls, especially for neural network layers and partitioning utilities. Many functions were integrated into `NamedArray` methods.","message":"Major API changes occurred in version 1.0, including refactoring of `haliax.nn` modules and removal of `haliax.partition` functions.","severity":"breaking","affected_versions":"<1.0"},{"fix":"Use `NamedArray.array` to access the underlying JAX array when calling JAX functions that are not aware of Haliax. 
Convert back using `hx.NamedArray(raw_array, axes)` to restore the axis names.","message":"Mixing `NamedArray` with raw `jax.Array` or `numpy.ndarray` can lose named dimension information or cause shape errors unless the conversion is handled explicitly.","severity":"gotcha","affected_versions":"All"},{"fix":"Debug code without `jax.jit` first. Ensure `Axis` objects are hashable, and define them once (globally or as consistently passed arguments) rather than creating them inside `jit`-compiled functions, where they might be re-instantiated.","message":"JAX's `jax.jit` compilation can make axis-related errors hard to debug, because some issues only surface once the function is traced and run.","severity":"gotcha","affected_versions":"All"},{"fix":"Where possible, replace `AxisSpec` patterns with direct `Axis` objects or explicit tuples of `Axis` instances to define dimensions.","message":"`AxisSpec` (e.g., tuples of `Axis` objects to specify a dimension) has been largely superseded by directly using `Axis` objects or tuples of `Axis` for clarity.","severity":"deprecated","affected_versions":">=0.10.0"}],"env_vars":null,"last_verified":"2026-04-17T00:00:00.000Z","next_check":"2026-07-16T00:00:00.000Z","problems":[{"fix":"Ensure all `Axis` objects used for a single `NamedArray` instance or within an operation have distinct names. E.g., `hx.NamedArray(array, (Batch, Batch))` is invalid.","cause":"Attempting to create a `NamedArray`, or to perform an operation, in which two or more axes of the same array have identical names.","error":"ValueError: Axis names must be unique within an array."},{"fix":"Inspect the axes of the input `NamedArray` with `array.axes` and check that the axis name the operation expects is present. 
Correct the axis definition or the operation call.","cause":"An operation (e.g., `hx.dot`, `hx.rearrange`) requires an axis with a specific name ('Input'), but the provided `NamedArray` does not have an axis with that name.","error":"SignatureMismatchError: Cannot find axis 'Input' on array with axes ('Batch', 'Hidden')."},{"fix":"An `Axis` object (e.g., `Batch = hx.Axis(\"batch\", 4)`) is an instance. Use the instance directly where an axis is expected, such as in a tuple of axes for `NamedArray` or as an argument to an operation.","cause":"Mistakenly attempting to call an `Axis` object as if it were a function or a constructor after it has already been instantiated (e.g., `Batch()`).","error":"TypeError: 'Axis' object is not callable"}]}