einops-exts: Einops Extensions
einops-exts provides a small set of its author's personal helper functions and extensions for the `einops` tensor manipulation library, aimed mainly at deep learning code (it includes a PyTorch-specific module, `einops_exts.torch`). It is currently at version 0.0.4 and is released irregularly; the latest release was in January 2023.
Warnings
- gotcha `einops-exts` is built directly on `einops`, so its operations inherit `einops` pattern semantics. An incorrect pattern string raises an error at runtime, typically from a shape mismatch or an invalid axis composition or decomposition, so users should be familiar with `einops` notation.
- gotcha The core `einops` library is framework-agnostic, but `einops-exts` is oriented toward PyTorch: `einops_exts.torch` imports PyTorch, and its `EinopsToAndFrom` is a `torch.nn.Module`. Users working with other frameworks (TensorFlow, JAX, NumPy) should check whether an equivalent exists for their backend, since the PyTorch-specific module cannot be used without PyTorch.
- breaking The library is at an early development stage (0.0.x versioning). No breaking changes are explicitly documented for `einops-exts` itself, but its API may change between releases, and changes in its core dependency (`einops`) could indirectly break its behavior.
Install
pip install einops-exts
Imports
- EinopsToAndFrom
from einops_exts.torch import EinopsToAndFrom
Quickstart
import torch
from torch import nn
from einops_exts.torch import EinopsToAndFrom
# EinopsToAndFrom(from_einops, to_einops, fn) takes two axis lists, not full "->" patterns.
# It rearranges the input from `from_einops` to `to_einops`, applies `fn`, then rearranges
# back using the axis sizes recorded from the input, so `fn` must not change any axis size.
class MyModel(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        # Conv2d expects channels-first input; in/out channels match so the shape can be restored
        self.conv = EinopsToAndFrom(
            'b h w c', 'b c h w',
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
    def forward(self, x):
        # x is (batch, height, width, channels); the output keeps the same layout and shape
        return self.conv(x)
# Example usage
model = MyModel()
input_tensor = torch.randn(1, 64, 64, 3)  # batch 1, 64x64, 3 channels (channels-last)
output_tensor = model(input_tensor)
print(f"Input shape: {input_tensor.shape}")    # torch.Size([1, 64, 64, 3])
print(f"Output shape: {output_tensor.shape}")  # torch.Size([1, 64, 64, 3])
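To see why the wrapped module has to preserve every axis size, the to-and-from mechanism can be sketched without PyTorch or einops. The helper `to_and_from_sketch` is hypothetical and not part of einops-exts: it uses NumPy's `transpose` as a stand-in for `einops.rearrange` and only handles plain axis permutations (no grouped axes such as `(h w)` and no `...`).

```python
import numpy as np


def to_and_from_sketch(x, from_axes, to_axes, fn):
    """Hypothetical helper mimicking the idea behind EinopsToAndFrom.

    Only plain axis permutations are supported (no '(h w)' groups, no '...').
    """
    src = from_axes.split()  # e.g. ['b', 'h', 'w', 'c']
    dst = to_axes.split()    # e.g. ['b', 'c', 'h', 'w']
    if len(src) != x.ndim or sorted(src) != sorted(dst):
        raise ValueError("both layouts must name the same axes, one per dimension")
    # Record each named axis size from the input layout
    sizes = dict(zip(src, x.shape))
    # Move axes into the layout fn expects, then apply fn
    y = fn(np.transpose(x, [src.index(name) for name in dst]))
    # Restoring the original layout relies on the recorded sizes, so fn must keep them
    expected = tuple(sizes[name] for name in dst)
    if y.shape != expected:
        raise ValueError(f"fn changed the shape: expected {expected}, got {y.shape}")
    # Move axes back to the input layout
    return np.transpose(y, [dst.index(name) for name in src])


x = np.random.rand(1, 64, 64, 3)  # channels-last, as in the Quickstart
out = to_and_from_sketch(x, 'b h w c', 'b c h w', lambda t: t * 2.0)
print(out.shape)  # (1, 64, 64, 3)

error_message = None
try:
    # Keeping only one channel changes the 'c' axis, so restoring the layout is refused
    to_and_from_sketch(x, 'b h w c', 'b c h w', lambda t: t[:, :1])
except ValueError as err:
    error_message = str(err)
    print(error_message)
```

A shape-changing `fn`, such as a `Conv2d` mapping 3 channels to 64, fails the same kind of check; with `EinopsToAndFrom` the failure would surface as an `einops` shape-mismatch error when the output is rearranged back.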