einx: Universal Notation for Tensor Operations
einx is a Python library that provides a universal interface to formulate tensor operations in frameworks such as Numpy, PyTorch, Jax, Tensorflow, and MLX, using an Einstein-inspired notation. It offers a streamlined approach to complex tensor manipulations, often by compiling operations to backend-specific function calls, which helps minimize overhead. The current version is 0.4.3; minor releases are frequent and typically fix bugs or add support for new backends.
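The idea of compiling an expression once to a backend call and reusing it can be sketched in a few lines. This is a hypothetical, heavily simplified illustration, not einx's actual implementation: `compile_sum` is a made-up name that parses a toy pattern such as "a [b] c" into a single `np.sum` call and caches the result with `functools.lru_cache`, so repeated calls with the same expression skip the parsing step.

```python
import functools

import numpy as np


# Toy illustration only (hypothetical helper, not part of einx): turn an
# expression string into a NumPy reduction once, then reuse it from a cache.
@functools.lru_cache(maxsize=None)
def compile_sum(expr: str):
    """Parse a toy pattern like "a [b] c" into a NumPy sum.

    Only space-separated single-letter axis names are supported;
    bracketed names mark the axes to reduce.
    """
    tokens = expr.split()
    axes = tuple(
        i for i, tok in enumerate(tokens) if tok.startswith("[") and tok.endswith("]")
    )
    ndim = len(tokens)

    def run(x):
        # Check the input rank against the pattern before calling the backend.
        if x.ndim != ndim:
            raise ValueError(f"expected {ndim} dims, got {x.ndim}")
        return np.sum(x, axis=axes)

    return run


x = np.ones((10, 20, 30))
y = compile_sum("a [b] c")(x)   # first call: parses the pattern
y2 = compile_sum("a [b] c")(x)  # second call: reuses the cached function
print(y.shape)                         # (10, 30)
print(compile_sum.cache_info().hits)   # 1
```

Caching on the expression string means the string parsing is paid once per distinct pattern; the real library's cache key and generated code are more involved than this sketch.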
Warnings
- gotcha einx itself is a lightweight notation library and does not include its own tensor implementation. It requires a separate tensor framework (e.g., NumPy, PyTorch, JAX, TensorFlow) to be installed and available in your environment to perform operations. For PyTorch, explicitly installing with `pip install einx[torch]` is recommended to ensure compatible backend versions.
- breaking Starting from `v0.2.1`, compiled `einx` functions no longer implicitly include the `einx` namespace in their dependency graph. Instead, they directly import and use the backend's namespace (e.g., `import torch`). If you were previously relying on `einx` being implicitly available within traced or compiled graphs, this change will break such workflows.
- gotcha Version `0.4.3` changed tensor parameter annotations from `typing.TypeVar` to `typing.Any`. This fixed cases where the stricter annotations did not hold (e.g., with mixed-type inputs or backend-dependent output types). Although this is a fix, users who relied on strict static analysis with earlier versions may see different type-checking results and may need to adjust their assumptions about how tensor types propagate through einx calls.
- gotcha Version `0.4.0` fully embraced vectorization as its core abstraction, defining expressions by analogy with loop notation. While this is intended to make the notation clearer and more consistent, users familiar with earlier versions may need to revisit how complex expressions are interpreted, especially regarding implicit loop structure and which axes are vectorized.
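The loop analogy can be illustrated with plain NumPy. This is an informal sketch, assuming the usual einx reading of brackets (as in the quickstart's `einx.sum("a [b] c", x)`): the unbracketed axes `a` and `c` act like loop indices, and the bracketed axis `b` is the axis the sum is applied to inside the loop body. It does not call einx, so it demonstrates the equivalence rather than einx's own behavior.

```python
import numpy as np

x = np.random.default_rng(0).random((4, 5, 6))

# Explicit loops over the vectorized axes a and c; the bracketed axis b is
# the one the elementary operation (here: sum) is applied to.
out = np.empty((x.shape[0], x.shape[2]))
for a in range(x.shape[0]):
    for c in range(x.shape[2]):
        out[a, c] = np.sum(x[a, :, c])

# The loops collapse into a single vectorized call over axis 1,
# which is what einx.sum("a [b] c", x) expresses.
vectorized = np.sum(x, axis=1)
print(np.allclose(out, vectorized))  # True
```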
Install
- pip install einx
- pip install einx[torch]
Imports
- einx
import einx
- einx.sum
import einx
einx.sum(...)
Quickstart
import einx
import numpy as np # Can be any supported backend like torch, jax, tensorflow, mlx
x = np.ones((10, 20, 30))
print(f"Input shape: {x.shape}")
# Sum-reduce the bracketed axis b (the second axis); axes a and c are vectorized
y = einx.sum("a [b] c", x)
print(f"Output shape after sum: {y.shape}")
# Permute and (un)flatten axes with the identity operation; the pattern
# describes a 2D input, so apply it to y: (10, 30) -> (20, 15)
z = einx.id("a (b c) -> (b a) c", y, b=2)
print(f"Output shape after id: {z.shape}")
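To make the two quickstart operations concrete, the sketch below reproduces them with plain NumPy. It is an assumed equivalence written for illustration, not how einx executes the calls internally. The pattern `a (b c) -> (b a) c` expects a two-dimensional input, so the sketch applies it to a `(10, 30)` array `t` of distinct values (a name introduced here only so the element reordering is visible).

```python
import numpy as np

x = np.ones((10, 20, 30))

# "a [b] c": reduce the bracketed axis b (axis 1); a and c are kept.
y = x.sum(axis=1)  # shape (10, 30), every entry 20.0

# "a (b c) -> (b a) c" with b=2, applied to a 2D array of distinct values.
t = np.arange(10 * 30).reshape(10, 30)
a, bc = t.shape
b = 2
c = bc // b  # c is inferred as 15
z = (
    t.reshape(a, b, c)   # unflatten (b c) into separate b and c axes
    .transpose(1, 0, 2)  # permute a b c -> b a c
    .reshape(b * a, c)   # flatten (b a) into one axis, b varying slowest
)
print(y.shape, z.shape)  # (10, 30) (20, 15)
```

In the flattened `(b a)` axis, row `bi * a + ai` of `z` holds row `ai` of `t` restricted to the `bi`-th block of 15 columns.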