Haliax: Named Tensors for JAX
Haliax (version 1.3) provides named tensors for JAX, making deep learning code more legible and less prone to common shape errors. Built on JAX, it lets users refer to tensor dimensions by name, which simplifies operations such as broadcasting, reduction, and concatenation. The library is actively developed, with frequent minor releases and occasional major updates.
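To make the broadcasting claim concrete, here is a plain NumPy model of what aligning dimensions by name involves. It is a conceptual sketch, not Haliax's implementation. `align_to` is a hypothetical helper, and axes are modeled as `(name, size)` tuples rather than Haliax `Axis` objects.

```python
import numpy as np

def align_to(data, axes, target):
    """Hypothetical helper: lay out `data`, whose dims are named by `axes`,
    in the order of `target`, adding size-1 dims for names it lacks."""
    names = [name for name, _ in axes]
    sizes = dict(axes)
    target_names = [name for name, _ in target]
    for name in names:
        if name not in target_names:
            raise ValueError(f"axis {name!r} is not in the target axes")
    # Move the existing dims into target order
    order = [names.index(name) for name in target_names if name in names]
    moved = np.transpose(data, order)
    # Insert a size-1 dim wherever the target has a name that `data` lacks
    return moved.reshape([sizes.get(name, 1) for name in target_names])

# Axes are modeled as (name, size) tuples rather than Haliax Axis objects
Batch, Features = ("batch", 4), ("features", 3)

x = np.arange(12.0).reshape(3, 4)     # dims named (features, batch)
bias = np.array([10.0, 20.0, 30.0])   # dims named (features,)

target = [Batch, Features]
result = align_to(x, [Features, Batch], target) + align_to(bias, [Features], target)
print(result.shape)  # (4, 3)
```

With positional NumPy you would have to remember that `x` needs transposing before adding `bias`. With names, the layout follows from the names alone.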
Common errors
- ValueError: Axis names must be unique within an array.
  cause: Creating a `NamedArray`, or performing an operation, in which two or more axes of the same array share a name.
  fix: Give every `Axis` used in a single `NamedArray`, or within one operation, a distinct name. For example, `hx.NamedArray(array, (Batch, Batch))` is invalid.
- SignatureMismatchError: Cannot find axis 'Input' on array with axes ('Batch', 'Hidden').
  cause: An operation (e.g., `hx.dot`, `hx.rearrange`) expects an axis named 'Input', but the `NamedArray` passed to it has no axis with that name.
  fix: Inspect the array's axes with `array.axes` and make sure the name the operation expects is present. Correct either the axis definition or the operation call.
- TypeError: 'Axis' object is not callable
  cause: Calling an already-instantiated `Axis` object as if it were a function or constructor (e.g., `Batch()`).
  fix: An `Axis` such as `Batch = hx.Axis("batch", 4)` is already an instance. Use it directly wherever an axis is expected, e.g., in the tuple of axes for a `NamedArray` or as an argument to an operation.
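The first two errors come from checks that can be modeled in a few lines of plain Python. This sketch is illustrative only: `check_unique` and `find_axis` are hypothetical helpers, axes are `(name, size)` tuples, and the messages are not guaranteed to match Haliax's wording. The third error needs no check at all, because an `Axis` instance is passed as-is and never called.

```python
from collections import Counter

def check_unique(axes):
    """Hypothetical: reject an axis list that repeats a name (cf. the ValueError above)."""
    counts = Counter(name for name, _ in axes)
    dupes = sorted(name for name, n in counts.items() if n > 1)
    if dupes:
        raise ValueError(f"Axis names must be unique within an array: {dupes}")

def find_axis(axes, name):
    """Hypothetical: look an axis up by name, failing loudly if it is absent."""
    for axis in axes:
        if axis[0] == name:
            return axis
    present = tuple(n for n, _ in axes)
    raise KeyError(f"Cannot find axis {name!r} on array with axes {present}")

# Axes modeled as (name, size) tuples
Batch, Hidden = ("batch", 4), ("hidden", 8)
axes = (Batch, Hidden)

check_unique(axes)                # passes: the names differ
print(find_axis(axes, "hidden"))  # ('hidden', 8)
```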
Warnings
- breaking: Version 1.0 introduced major API changes, including a refactoring of the `haliax.nn` modules and the removal of `haliax.partition` functions.
- gotcha: Mixing `NamedArray` with raw `jax.Array` or `numpy.ndarray` values can silently drop dimension names or cause shape errors unless the conversion is handled explicitly.
- gotcha: Code compiled with `jax.jit` can be harder to debug, because axis errors surface while the function is being traced and the resulting tracebacks run through JAX internals.
- deprecated: Writing axis specifications through the `AxisSpec` alias has been largely superseded by passing `Axis` objects, or tuples of `Axis` objects, directly, which is clearer.
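The mixing gotcha comes down to this: once data leaves a named wrapper, its names are gone, and reattaching them is only safe if the sizes are checked. The hypothetical `rewrap` helper below shows that explicit check in plain NumPy. It models the idea and is not part of Haliax.

```python
import numpy as np

# Axes are modeled as (name, size) tuples; this is not Haliax's Axis class.
Batch, Features = ("batch", 4), ("features", 8)

def rewrap(raw, axes):
    """Hypothetical helper: attach names to a raw array only if every size matches."""
    expected = tuple(size for _, size in axes)
    if raw.shape != expected:
        raise ValueError(f"shape {raw.shape} does not match axes {axes}")
    return raw, tuple(axes)

raw = np.zeros((4, 8))
data, names = rewrap(raw, (Batch, Features))  # OK: sizes line up in order

# A transposed raw array has the right sizes in the wrong order; the check catches it
try:
    rewrap(raw.T, (Batch, Features))
except ValueError as err:
    print("rejected:", err)
```

The check only catches the swap because the two sizes differ. If both axes had the same size, a transposed array would pass silently, which is the kind of error that keeping data inside a named wrapper avoids.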
Install
- pip install haliax
- pip install haliax "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
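After installing, you can confirm which versions pip resolved without importing the libraries, using only the standard library's `importlib.metadata`. The package names come from the commands above. `jaxlib` is JAX's compiled companion package.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

for pkg in ("haliax", "jax", "jaxlib"):
    print(pkg, installed_version(pkg) or "not installed")
```

If JAX imports successfully, `jax.devices()` lists the devices it can see, which is a quick way to check that the CUDA build was picked up.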
Imports
- NamedArray
from haliax import NamedArray
- Axis
from haliax import Axis
- product
from haliax import product
- nn
import haliax.nn as nn
- Linear
from haliax.nn import Linear
Quickstart
import haliax as hx
import jax
import jax.random as jr
# 1. Define axes with their names and sizes
Batch = hx.Axis("batch", 4)
Features = hx.Axis("features", 8)
# 2. Create a NamedArray of random values
# The shape is given as a tuple of Axis objects, listed in dimension order
key = jr.PRNGKey(0)
data_array = hx.random.normal(key, (Batch, Features))
print(f"NamedArray axes: {data_array.axes}")
print(f"Values at batch index 0: {data_array.take(Batch, 0).array.round(2)}")
# 3. Perform an operation, e.g., sum over the Features axis
summed_array = data_array.sum(Features)
print(f"Summed array axes: {summed_array.axes}") # Expected: (Batch,)
print(f"Summed array values: {summed_array.array.round(2)}")
# 4. Dot product example
# Define an output axis; the weight array shares the Features axis with data_array
Out = hx.Axis("out", 3)
weights = hx.random.normal(jr.PRNGKey(1), (Features, Out))
# Axes are matched by name: contract over the shared Features axis, keep Batch and Out
product_array = hx.dot(data_array, weights, axis=Features)
print(f"Dot product array axes: {product_array.axes}")  # Expected: (Batch, Out)
print(f"Dot product values shape: {product_array.array.shape}")  # Expected: (4, 3)
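Contraction by name can be modeled with `numpy.einsum` by giving each distinct axis name its own subscript letter. Contracted names disappear from the output, names shared but not contracted are batched over, and all other names survive. The helper `named_dot` below is a hypothetical sketch of that bookkeeping, not Haliax's implementation.

```python
import string
import numpy as np

def named_dot(a, a_names, b, b_names, contract):
    """Hypothetical sketch: contract the axes named in `contract`, keep all others.

    Each distinct axis name gets one einsum letter, so axes are matched by name,
    never by position. Names shared by both inputs but not contracted are batched.
    """
    letters = {}
    for name in (*a_names, *b_names):
        letters.setdefault(name, string.ascii_letters[len(letters)])
    out_names = [n for n in letters if n not in contract]
    spec = "{},{}->{}".format(
        "".join(letters[n] for n in a_names),
        "".join(letters[n] for n in b_names),
        "".join(letters[n] for n in out_names),
    )
    return np.einsum(spec, a, b), tuple(out_names)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))  # axes: ("batch", "features")
w = rng.normal(size=(8, 3))  # axes: ("features", "out")

y, y_names = named_dot(x, ("batch", "features"), w, ("features", "out"), {"features"})
print(y_names, y.shape)  # ('batch', 'out') (4, 3)
```

Swapping the argument order changes only the order of the output axes, not which dimensions are contracted. That is the practical payoff of matching by name.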