Einsum optimization using opt_einsum and PyTorch FX

0.1.4 · maintenance · verified Wed Apr 15

opt-einsum-fx is a Python library that uses opt_einsum and PyTorch FX to optimize Einstein summation (einsum) expressions in PyTorch computation graphs. It aims to reduce the execution time and memory footprint of complex tensor contractions by choosing an efficient order in which to contract the operands of each einsum. The current version, 0.1.4, was released in November 2021, and the project is in maintenance mode.
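The sketch below is a rough illustration of why contraction order matters. It counts multiply-adds for two ways of evaluating the quickstart's expression `zij,zjk,zk->zi`, using the quickstart's example shapes (z=7, i=4, j=5, k=3).

The cost model is a simplifying assumption, not opt_einsum's actual cost function: it charges one multiply-add per combination of the indices touched by each pairwise contraction. The helper `pair_cost` is hypothetical.

```python
from math import prod

# Index sizes taken from the quickstart's example inputs:
# a is (z, i, j) = (7, 4, 5), b is (z, j, k) = (7, 5, 3), vec is (z, k) = (7, 3).
SIZES = {"z": 7, "i": 4, "j": 5, "k": 3}

def pair_cost(lhs: str, rhs: str) -> int:
    # Hypothetical, simplified cost model: one multiply-add for every
    # combination of the indices appearing in either operand.
    return prod(SIZES[c] for c in set(lhs) | set(rhs))

# Order A: contract the two matrices first, then apply the vector.
#   zij,zjk->zik   then   zik,zk->zi
cost_a = pair_cost("zij", "zjk") + pair_cost("zik", "zk")

# Order B: apply the vector to b first, then contract with a.
#   zjk,zk->zj     then   zij,zj->zi
cost_b = pair_cost("zjk", "zk") + pair_cost("zij", "zj")

print(f"matrices first: {cost_a} multiply-adds")  # 504
print(f"vector first:   {cost_b} multiply-adds")  # 245
```

Under this model, applying the vector first roughly halves the work because it avoids the batched matrix-matrix product. opt_einsum makes this kind of choice automatically, although the path it actually selects for a given graph may differ.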

Install

pip install opt-einsum-fx

Imports

import torch
import torch.fx
import opt_einsum_fx

Quickstart

This quickstart shows how to use `opt_einsum_fx` to optimize a PyTorch function that contains an `einsum` call. It takes three steps:

- Symbolically trace the function with `torch.fx.symbolic_trace`.
- Supply example inputs so that tensor shapes can be inferred.
- Apply `opt_einsum_fx.optimize_einsums_full` to obtain an optimized graph module.

Finally, the outputs of the original and optimized graphs are compared to check correctness.

import torch
import torch.fx
import opt_einsum_fx

def einmatvecmul(a, b, vec):
    """Batched matrix-matrix-vector product using einsum"""
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)

# 1. Create an FX graph module from the function
graph_mod = torch.fx.symbolic_trace(einmatvecmul)

# 2. Define example inputs for shape propagation and optimization
# These shapes are used to determine the optimal contraction path.
example_inputs = (
    torch.randn(7, 4, 5),
    torch.randn(7, 5, 3),
    torch.randn(7, 3)
)

# 3. Optimize the einsums within the FX graph
graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=example_inputs
)

# 4. (Optional) Print the optimized code to see the changes
print("Original code:\n", graph_mod.code)
print("Optimized code:\n", graph_opt.code)

# 5. Run the optimized graph and verify correctness
output_original = graph_mod(*example_inputs)
output_optimized = graph_opt(*example_inputs)

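# Reordering changes the floating-point summation order, so allow a small absolute tolerance
assert torch.allclose(output_original, output_optimized, atol=1e-6)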
print("\nOptimization successful and outputs match!")
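The sketch below shows what a reordered contraction looks like when written by hand. It expresses the same product as two pairwise einsums: `vec` is applied to `b` first, and the result is then contracted with `a`.

It uses only `torch.einsum` and is for illustration. The path and operations that `optimize_einsums_full` actually emits may differ, so inspect `graph_opt.code` for the real result.

```python
import torch

torch.manual_seed(0)
a = torch.randn(7, 4, 5)    # indices z, i, j
b = torch.randn(7, 5, 3)    # indices z, j, k
vec = torch.randn(7, 3)     # indices z, k

# Single einsum over all three operands, as in einmatvecmul
full = torch.einsum("zij,zjk,zk->zi", a, b, vec)

# Hand-written reordering: batched matrix-vector product with b first...
bv = torch.einsum("zjk,zk->zj", b, vec)        # shape (7, 5)
# ...then a second batched matrix-vector product with a
two_step = torch.einsum("zij,zj->zi", a, bv)   # shape (7, 4)

# Same result up to floating-point rounding
print(torch.allclose(full, two_step, atol=1e-5))
```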
