einshape
einshape is a DSL-based reshaping library that unifies array manipulation operations such as reshape, squeeze, expand_dims, and transpose behind a single equation syntax, much as `einsum` unifies `matmul` and `tensordot`. It primarily targets the JAX and TensorFlow frameworks. The current version is 1.0, released in December 2022; the library is stable, but releases are infrequent.
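To make the unification concrete, the sketch below uses plain NumPy calls (jax.numpy provides the same functions) to spell out what four einshape-style equations do. The equations in the comments are only labels. `a1b->ab` mirrors the squeeze example on this page, while `ab->a1b` for expand_dims is an assumption about the DSL rather than something documented here.

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# "abc->acb": transpose, swapping the last two axes.
transposed = np.transpose(x, (0, 2, 1))

# "abc->(ab)c": reshape, merging the first two axes into one axis of size 2*3.
merged = x.reshape(2 * 3, 4)

# "a1b->ab": squeeze, dropping an axis of size 1.
with_unit = x.reshape(2, 1, 12)
squeezed = np.squeeze(with_unit, axis=1)

# "ab->a1b" (assumed syntax): expand_dims, inserting an axis of size 1.
expanded = np.expand_dims(squeezed, axis=1)

print(transposed.shape, merged.shape, squeezed.shape, expanded.shape)
```

With einshape, each of these four separate calls becomes one equation string passed to the same function.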
Common errors
- ModuleNotFoundError: No module named 'jax'
  Cause: JAX is not installed, but `jax_einshape` was imported or called.
  Fix: Install JAX and jaxlib: `pip install jax jaxlib`.
- ValueError: Axes lengths incompatible: X and Y
  Cause: The dimensions in the `einshape` equation do not describe a valid reshape of the input array's shape. Typical reasons are an incorrect grouping or split, or missing keyword arguments for the sizes of new dimensions.
  Fix: Check the equation against the input array's shape. When splitting a dimension, pass the required sizes as keyword arguments (e.g., `einshape('(ab)c->abc', array, a=expected_a_size)`).
- ValueError: Equation 'ab->a' is not a valid equation. Every index name that is present on the left-hand side of an equation must also be present on the right-hand side.
  Cause: The equation tries to drop a named dimension (here `b` in `ab->a`). einshape does not discard named dimensions; only unit dimensions written as `1` can be removed (squeezed).
  Fix: Account for every left-hand-side index on the right-hand side. To squeeze a size-1 dimension, mark it with `1`, e.g., `a1b->ab` instead of `ab->a`.
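The rule behind the last error can be checked mechanically before calling einshape. The sketch below is a hypothetical helper (`missing_rhs_indices` is not part of einshape) that treats every ASCII letter as an index name, which is an assumption, and ignores groups, `...`, and `1`. It only approximates the library's real validation, which also checks shapes.

```python
from __future__ import annotations

import re


def missing_rhs_indices(equation: str) -> set[str]:
    """Hypothetical helper (not part of einshape).

    Returns the index names that appear on the left-hand side of an
    einshape-style equation but not on the right-hand side. Only letters
    count as index names; groups "(...)", the ellipsis "..." and the
    unit-dimension marker "1" are ignored.
    """
    lhs, rhs = equation.split("->")
    lhs_names = set(re.findall(r"[A-Za-z]", lhs))
    rhs_names = set(re.findall(r"[A-Za-z]", rhs))
    return lhs_names - rhs_names


for eq in ["ab->a", "a1b->ab", "(ab)c->abc", "ab...->(ab)..."]:
    missing = missing_rhs_indices(eq)
    status = "ok" if not missing else f"drops {sorted(missing)}"
    print(f"{eq:16} {status}")
```

Only `ab->a` is flagged: the other equations keep every named index, either on its own or inside a group.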
Warnings
- gotcha einshape does not declare JAX or TensorFlow as dependencies. Install your preferred array backend (e.g., JAX) separately; otherwise the corresponding `einshape` functions cannot run.
- gotcha Understanding how the DSL handles grouped dimensions, written in parentheses such as `(ab)`, and the ellipsis `...` is essential. When splitting a grouped dimension, pass the sizes of the new dimensions as keyword arguments (e.g., `n=batch_size`); for a two-way split such as `(ab)`, giving one of the two sizes is enough. Omitting them leads to a `ValueError` or an unexpected shape.
- gotcha Every index name on the left-hand side of an `einshape` equation must also appear on the right-hand side, either on its own or inside a group such as `(ab)`; only unit dimensions written as `1` can be dropped. Omitting a dimension, or introducing an unmatched index on the right, causes shape errors.
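A grouped dimension of size 6 written as `(ab)` could split as 1 x 6, 2 x 3, 3 x 2 or 6 x 1, so the equation alone cannot fix the sizes. The sketch below is a hypothetical helper (`infer_split_sizes` is not an einshape function) that applies the common rule of inferring at most one missing size; whether einshape resolves sizes in exactly this way is an assumption.

```python
from math import prod


def infer_split_sizes(group_size, names, **sizes):
    """Hypothetical helper: resolve the sizes of a split group such as "(ab)".

    At most one name may be left out of `sizes`; its size is inferred so
    that the product of all sizes equals `group_size`.
    """
    unknown = [n for n in names if n not in sizes]
    if len(unknown) > 1:
        raise ValueError(f"cannot infer sizes for {unknown}; pass all but one as kwargs")
    known = prod(sizes[n] for n in names if n in sizes)  # prod of nothing is 1
    if unknown:
        if group_size % known:
            raise ValueError(f"{group_size} is not divisible by {known}")
        sizes[unknown[0]] = group_size // known
    elif known != group_size:
        raise ValueError(f"sizes multiply to {known}, expected {group_size}")
    return {n: sizes[n] for n in names}


# Mirrors "(ab)c->abc" with a=2 on an axis of size 6.
print(infer_split_sizes(6, "ab", a=2))  # {'a': 2, 'b': 3}
```

Calling `infer_split_sizes(6, "ab")` with no sizes raises, which is the situation the gotcha above warns about.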
Install
- pip install einshape  (release from PyPI)
- pip install git+https://github.com/deepmind/einshape  (latest version from GitHub)
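Because einshape does not install an array backend itself, checking what is importable first can prevent a confusing `ModuleNotFoundError`. This sketch uses only the standard library; the helper name `available_backends` is hypothetical.

```python
import importlib.util


def available_backends(candidates=("einshape", "jax", "tensorflow")):
    """Map each top-level package name to whether it can be imported here.

    find_spec only locates a top-level package; it does not import it.
    """
    return {name: importlib.util.find_spec(name) is not None for name in candidates}


status = available_backends()
print(status)
if not status["jax"]:
    print("JAX is missing: run `pip install jax jaxlib` before using jax_einshape.")
```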
Imports
- jax_einshape
from einshape import jax_einshape as einshape
- engine
from einshape import engine
Quickstart
import jax.numpy as jnp
from einshape import jax_einshape as einshape
x = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
print(f"Original shape: {x.shape}\n{x}\n")
# Equivalent to jnp.transpose(x, (0, 2, 1)); result shape (2, 4, 3)
y = einshape("abc->acb", x)
print(f"Transposed (abc->acb) shape: {y.shape}\n{y}\n")
# Equivalent to x.reshape(6, 4): merge the two leading dimensions; result shape (6, 4)
z = einshape("ab...->(ab)...", x)
print(f"Combined leading dims (ab...->(ab)...) shape: {z.shape}\n{z}\n")
# Equivalent to z.reshape(2, 3, 4): split the merged dimension; a=2 is given and b=3 is inferred
w = einshape("(ab)c->abc", z, a=2)
print(f"Split dim ((ab)c->abc) shape: {w.shape}\n{w}")
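The `...` in `ab...->(ab)...` matches any number of trailing axes, so one equation applies to arrays of different rank. The plain-NumPy sketch below illustrates that behaviour; the helper names `merge_leading` and `split_leading` are hypothetical, not einshape functions. It also checks that splitting with `a=2` undoes the merge, as in the last Quickstart step.

```python
import numpy as np


def merge_leading(x):
    """Equivalent of "ab...->(ab)...": merge the first two axes, keep the rest."""
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def split_leading(x, a):
    """Equivalent of "(ab)...->ab...": split the first axis, given the size of a."""
    if x.shape[0] % a:
        raise ValueError(f"axis of size {x.shape[0]} cannot be split with a={a}")
    return x.reshape((a, x.shape[0] // a) + x.shape[1:])


# The same pair of operations works for rank 2, 3 and 4 inputs.
for shape in [(2, 3), (2, 3, 4), (2, 3, 4, 5)]:
    x = np.arange(np.prod(shape)).reshape(shape)
    merged = merge_leading(x)
    restored = split_leading(merged, a=2)
    print(shape, "->", merged.shape, "->", restored.shape)
    assert np.array_equal(restored, x)
```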