Haliax: Named Tensors for JAX

1.3 · active · verified Fri Apr 17

Haliax (version 1.3) provides named tensors for JAX, improving readability and reducing common shape-related errors in deep learning models. Users refer to tensor dimensions by name rather than by position, which simplifies operations such as broadcasting, reduction, and concatenation. The library is actively developed with frequent minor releases and occasional major updates.

Quickstart

This quickstart demonstrates how to define named axes, create `NamedArray` instances with these axes, and perform basic operations like reduction and dot products, highlighting how Haliax manages dimension alignment by name.

import haliax as hx
import jax
import jax.random as jr

# 1. Define axes with their names and sizes
Batch = hx.Axis("batch", 4)
Features = hx.Axis("features", 8)

# 2. Create a NamedArray
# The shape is given as a tuple of Axis objects, in order
key = jr.PRNGKey(0)
data_array = hx.random.normal(key, (Batch, Features))

print(f"NamedArray axes: {data_array.axes}")
print(f"Value for batch index 0: {data_array.take(Batch, 0).array.round(2)}")  # take(axis, index)

# 3. Perform an operation, e.g., sum over the Features axis
summed_array = data_array.sum(Features)
print(f"Summed array axes: {summed_array.axes}") # Expected: (Batch,)
print(f"Summed array values: {summed_array.array.round(2)}")

# 4. Dot product example
# Haliax contracts over an axis that both arrays share by name, so the
# second array reuses Features and adds a new output axis.
# (Hidden and its size of 16 are illustrative choices.)
Hidden = hx.Axis("hidden", 16)
weights = hx.random.normal(jr.PRNGKey(1), (Features, Hidden))

# Dot product, contracting over the shared Features axis
product_array = hx.dot(data_array, weights, axis=Features)
print(f"Dot product array axes: {product_array.axes}") # Expected: (Batch, Hidden)
print(f"Dot product values shape: {product_array.array.shape}")
