einshape

1.0 · active · verified Thu Apr 16

einshape is a DSL-based reshaping library that unifies and simplifies array-manipulation operations such as reshape, squeeze, expand_dims, and transpose, much as `einsum` unifies `matmul` and `tensordot`. It primarily targets the JAX and TensorFlow frameworks. The current version, 1.0, was released in December 2022; the project is stable, but releases are infrequent.
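
The `einsum` analogy can be made concrete with NumPy, used here purely for illustration (einshape itself is not involved). One `einsum` equation reproduces both `matmul` and `tensordot`, and `einsum` can also express a pure axis permutation, which is the kind of operation einshape's equations generalize.

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(3, 4)

# A single einsum equation expresses a matrix product...
c_einsum = np.einsum("ij,jk->ik", a, b)
c_matmul = a @ b                            # np.matmul
c_tensordot = np.tensordot(a, b, axes=1)    # contract last axis of a with first axis of b

# ...and the same notation also expresses a transpose.
a_t = np.einsum("ij->ji", a)

print(np.allclose(c_einsum, c_matmul), np.allclose(c_einsum, c_tensordot))
print(np.array_equal(a_t, a.T))
```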

Quickstart

This quickstart demonstrates basic reshaping with `einshape` on JAX arrays: transposing axes, merging leading axes, and splitting an axis, each expressed as a single DSL equation. JAX must be installed separately for this example to run.

```python
import jax.numpy as jnp
from einshape import jax_einshape as einshape

x = jnp.arange(2 * 3 * 4).reshape((2, 3, 4))
print(f"Original shape: {x.shape}\n{x}\n")

# Swap the last two axes; equivalent to jnp.transpose(x, (0, 2, 1)).
y = einshape("abc->acb", x)
print(f"Transposed (abc->acb) shape: {y.shape}\n{y}\n")

# Merge the two leading axes; `...` stands for the remaining axes (here c).
# Equivalent to x.reshape((6, 4)).
z = einshape("ab...->(ab)...", x)
print(f"Combined leading dims (ab...->(ab)...) shape: {z.shape}\n{z}\n")

# Split the first axis into a=2 and b; b = 6 // 2 = 3 is inferred.
# Equivalent to z.reshape((2, 3, 4)).
w = einshape("(ab)c->abc", z, a=2)
print(f"Split dim ((ab)c->abc) shape: {w.shape}\n{w}")
```
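
To see how a single equation can stand in for reshape and transpose, here is a rough NumPy-only sketch. It is an assumption about how such a DSL can be lowered, not einshape's actual implementation. The hypothetical `mini_einshape` supports only single-letter axes and parenthesized groups (no `...`). It splits grouped input dimensions with one reshape, reorders axes with one transpose, and merges output groups with a final reshape. It reproduces the quickstart's three results without JAX installed.

```python
import numpy as np


def _parse(side):
    """Parse e.g. 'a(bc)d' into groups [['a'], ['b', 'c'], ['d']] (hypothetical helper)."""
    groups, i = [], 0
    while i < len(side):
        if side[i] == "(":
            j = side.index(")", i)
            groups.append(list(side[i + 1:j]))
            i = j + 1
        else:
            groups.append([side[i]])
            i += 1
    return groups


def mini_einshape(equation, x, **sizes):
    """Hypothetical subset of an einshape-style DSL: reshape -> transpose -> reshape."""
    lhs, rhs = equation.split("->")
    in_groups, out_groups = _parse(lhs), _parse(rhs)
    if len(in_groups) != x.ndim:
        raise ValueError("equation rank does not match input rank")
    sizes = dict(sizes)
    # Resolve each axis size; at most one unknown axis per input group.
    for group, dim in zip(in_groups, x.shape):
        unknown = [ax for ax in group if ax not in sizes]
        known = int(np.prod([sizes[ax] for ax in group if ax in sizes]))
        if len(unknown) > 1:
            raise ValueError(f"cannot infer sizes for {unknown}")
        if unknown:
            if dim % known:
                raise ValueError(f"dimension {dim} not divisible by {known}")
            sizes[unknown[0]] = dim // known
        elif known != dim:
            raise ValueError(f"sizes for {group} do not multiply to {dim}")
    in_axes = [ax for g in in_groups for ax in g]
    out_axes = [ax for g in out_groups for ax in g]
    if sorted(in_axes) != sorted(out_axes):
        raise ValueError("input and output must use the same axis letters")
    # 1) Split grouped input dims into one dim per axis letter (row-major).
    x = x.reshape([sizes[ax] for ax in in_axes])
    # 2) Permute the axes into output order.
    x = x.transpose([in_axes.index(ax) for ax in out_axes])
    # 3) Merge the axes of each output group.
    return x.reshape([int(np.prod([sizes[ax] for ax in g])) for g in out_groups])


x = np.arange(2 * 3 * 4).reshape((2, 3, 4))
y = mini_einshape("abc->acb", x)         # like np.transpose(x, (0, 2, 1))
z = mini_einshape("abc->(ab)c", x)       # like x.reshape((6, 4))
w = mini_einshape("(ab)c->abc", z, a=2)  # like z.reshape((2, 3, 4)); b=3 inferred
print(y.shape, z.shape, w.shape)         # (2, 4, 3) (6, 4) (2, 3, 4)
```

One reshape-transpose-reshape pipeline covers transposes, merges, and splits alike. A fuller version would also need `...` support and stricter validation (for example, rejecting repeated axis letters).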