Torch Einops Utils

0.0.30 · active · verified Thu Apr 16

torch-einops-utils is a collection of personal utility functions for PyTorch and Einops that provide convenient abstractions for common tensor manipulations. It is at version 0.0.30 and is under active development, with frequent releases.

Install

pip install torch-einops-utils

Quickstart

This quickstart shows how to use `EinopsToAndFrom`, a utility class that wraps a function or `nn.Module` with an Einops input pattern and an output pattern: the input is rearranged before the wrapped callable runs, and its result is rearranged afterwards. The example subclasses it with the pattern 'b n d' on both sides, which is a no-op rearrangement, and applies it first to `nn.Identity` and then to a simple lambda.

from typing import Callable

import torch
from torch import nn
from torch_einops_utils import EinopsToAndFrom

class Foo(EinopsToAndFrom):
    def __init__(self, fn: Callable):
        # EinopsToAndFrom takes an input pattern, an output pattern, and a callable or nn.Module
        super().__init__('b n d', 'b n d', fn)

    def forward(self, x):
        # Delegate to EinopsToAndFrom's forward: it applies the input pattern,
        # calls the `fn` given in __init__, then applies the output pattern.
        # Returning self.fn(x) directly here would bypass both rearrangements.
        # In this example, 'b n d' -> 'b n d' is a no-op rearrangement,
        # so fn sees the input shape unchanged.
        return super().forward(x)

# Example usage with a simple nn.Identity
model = Foo(nn.Identity())
x = torch.randn(1, 10, 32) # Batch, Sequence Length, Dimension
y = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

# Example with a lambda function
dummy_fn = lambda z: z * 2 # Multiply by 2
model_lambda = Foo(dummy_fn)
y_lambda = model_lambda(x)
print(f"Output with lambda: {y_lambda.shape}")
print(f"First element value: {y_lambda[0, 0, 0].item():.2f} (input: {x[0, 0, 0].item():.2f})")
