einx: Universal Notation for Tensor Operations

0.4.3 · active · verified Sat Apr 11

einx is a Python library that provides a universal interface for formulating tensor operations in frameworks such as NumPy, PyTorch, JAX, TensorFlow, and MLX, using an Einstein-inspired notation. It expresses complex tensor manipulations concisely, often compiling each operation to backend-specific function calls, which keeps overhead low. The current version is 0.4.3; minor releases are frequent and mainly fix bugs and add support for new backends.
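The description says einx often compiles operations to backend-specific function calls to keep overhead low. The toy sketch below illustrates the general idea of doing the string-parsing work once per pattern and caching it, using `functools.lru_cache` and NumPy. `compile_reduction` and `toy_sum` are hypothetical helpers; this is not einx's actual implementation or parser.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=None)
def compile_reduction(pattern):
    """Hypothetical toy "compiler": map a pattern such as "a [b] c" to the
    tuple of axis positions to reduce. It only understands space-separated
    axis names, some wrapped in brackets; it is not einx's parser."""
    names = pattern.split()
    return tuple(
        i for i, name in enumerate(names) if name.startswith("[") and name.endswith("]")
    )


def toy_sum(pattern, t):
    """Sum over the bracketed axes of `t` (hypothetical helper)."""
    # The pattern string is parsed only the first time it is seen; later calls
    # with the same pattern reuse the cached result and go straight to the
    # backend call (np.sum here).
    reduce_axes = compile_reduction(pattern)
    return np.sum(t, axis=reduce_axes)


x = np.ones((10, 20, 30))
y1 = toy_sum("a [b] c", x)  # first call: parses the pattern
y2 = toy_sum("a [b] c", x)  # second call: cache hit, no parsing
print(y1.shape)  # (10, 30)
print(compile_reduction.cache_info().hits)  # 1
```

The cache is keyed on the pattern string, so repeated calls with the same notation skip parsing entirely and pay only for the underlying backend call.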

Install

pip install einx

Imports

import einx

Quickstart

This example demonstrates basic tensor operations using `einx` with a NumPy array. `einx` determines the backend (e.g., NumPy, PyTorch, JAX) from the type of the input tensors, so the same call works with any supported framework. The string notation defines how axes are manipulated: brackets mark the axes an operation acts on, and `->` separates the input from the output.

import einx
import numpy as np # Can be any supported backend like torch, jax, tensorflow, mlx

x = np.ones((10, 20, 30))
print(f"Input shape: {x.shape}")

# Sum-reduction along the second axis (the bracketed axis b); a and c are kept
y = einx.sum("a [b] c", x)
print(f"Output shape after sum: {y.shape}")

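# Unflatten, permute, and flatten axes with the identity operation.
# "a (b c)" describes a 2D input, so apply it to y (shape (10, 30)), not the 3D x.
z = einx.id("a (b c) -> (b a) c", y, b=2)
print(f"Output shape after id: {z.shape}")  # (20, 15)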
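For readers who know NumPy, the two quickstart expressions can be spelled out with an explicit reduction and a reshape/transpose sequence. The helper `unflatten_permute_flatten` is hypothetical and only illustrates the axis bookkeeping; einx derives these steps from the pattern string itself. Since `a (b c) -> (b a) c` describes a two-dimensional input, the sketch applies it to the (10, 30) sum result.

```python
import numpy as np

# Plain-NumPy sketch of what the two quickstart expressions describe.
x = np.arange(10 * 20 * 30, dtype=np.float64).reshape(10, 20, 30)

# "a [b] c": reduce the bracketed axis b (axis 1); a and c remain.
y = x.sum(axis=1)  # shape (10, 30)


def unflatten_permute_flatten(t, b):
    """Hypothetical helper spelling out "a (b c) -> (b a) c" for a 2D array."""
    a, bc = t.shape
    if bc % b != 0:
        raise ValueError("second axis length must be divisible by b")
    c = bc // b
    t = t.reshape(a, b, c)      # unflatten: a (b c) -> a b c
    t = t.transpose(1, 0, 2)    # permute:   a b c   -> b a c
    return t.reshape(b * a, c)  # flatten:   b a c   -> (b a) c


z = unflatten_permute_flatten(y, b=2)
print(y.shape, z.shape)  # (10, 30) (20, 15)
```

The `(b a)` grouping makes `b` the outer (slower-varying) axis, so row `b * 10 + a` of `z` holds column block `b` of row `a` in `y`.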
