Einsum optimization using opt_einsum and PyTorch FX
opt-einsum-fx is a Python library that leverages opt_einsum and PyTorch FX to optimize Einstein summation (einsum) expressions within PyTorch computation graphs. It aims to reduce the overall execution time and memory footprint of complex tensor contractions by intelligently reordering operations. The current version is 0.1.4, with the last release in November 2021, indicating a maintenance-level release cadence.
Warnings
- gotcha The latest release (v0.1.4) explicitly lists compatibility with PyTorch 1.9 and 1.10. While it might work with newer PyTorch versions (e.g., 2.x), direct compatibility with the latest PyTorch versions is not guaranteed and should be tested by the user.
- gotcha `opt_einsum_fx` relies on `torch.fx.symbolic_trace` to build computation graphs. `symbolic_trace` has limitations and may not correctly trace all Python language features or PyTorch operations. Functions with control flow, external data dependencies, or non-traceable operations will fail or produce incorrect graphs.
- gotcha The underlying `opt_einsum` library, used by `opt_einsum_fx`, employs heuristic algorithms to find contraction paths because determining the truly optimal path for einsum expressions is an NP-hard problem. This means the generated 'optimized' path might not always be the absolute best, especially for very complex expressions.
- gotcha Inefficient einsum contraction orders can lead to the creation of extremely large intermediate tensors, potentially causing out-of-memory (OOM) errors. `opt_einsum_fx`'s `EfficientShapeProp` specifically avoids executing einsums during shape propagation to mitigate this, but if `opt_einsum_fx` fails to optimize an expression, or is not applied, such issues can arise.
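To illustrate the tracing limitation, here is a minimal sketch (using only `torch.fx`, not `opt_einsum_fx`) of a function whose data-dependent control flow cannot be symbolically traced; the function name is hypothetical:

```python
import torch
import torch.fx

def data_dependent(x):
    # Branching on a tensor's runtime value cannot be captured by
    # symbolic tracing, which sees only proxy objects, not values.
    if x.sum() > 0:
        return x + 1
    return x - 1

traceable = True
try:
    torch.fx.symbolic_trace(data_dependent)
except torch.fx.proxy.TraceError as e:
    traceable = False
    print("Not traceable:", e)
```

Such functions must be refactored (e.g. by hoisting the branch out of the traced region) before `opt_einsum_fx` can optimize them.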
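The effect of contraction order can be seen even on a small expression. A sketch comparing two manual orderings of `zij,zjk,zk->zi` with plain `torch.einsum` (shapes chosen for illustration):

```python
import torch

a = torch.randn(7, 4, 5)    # (z, i, j)
b = torch.randn(7, 5, 3)    # (z, j, k)
vec = torch.randn(7, 3)     # (z, k)

# Order 1: contract b with vec first -> (7, 5) intermediate,
# then contract with a -> (7, 4). Intermediates stay small.
bv = torch.einsum("zjk,zk->zj", b, vec)
out_good = torch.einsum("zij,zj->zi", a, bv)

# Order 2: contract a with b first -> (7, 4, 3) intermediate.
# With large i and k dimensions this ordering can allocate a huge
# intermediate tensor and trigger an OOM error.
ab = torch.einsum("zij,zjk->zik", a, b)
out_bad = torch.einsum("zik,zk->zi", ab, vec)

# Both orderings compute the same result (up to float rounding).
assert torch.allclose(out_good, out_bad, atol=1e-5)
```

Choosing the cheaper ordering automatically is exactly what `opt_einsum`'s path search does on behalf of `opt_einsum_fx`.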
Install
- pip install opt_einsum_fx
Imports
- opt_einsum_fx
import opt_einsum_fx
- optimize_einsums_full
from opt_einsum_fx import optimize_einsums_full
Quickstart
import torch
import torch.fx
import opt_einsum_fx
def einmatvecmul(a, b, vec):
    """Batched matrix-matrix-vector product using einsum"""
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)
# 1. Create an FX graph module from the function
graph_mod = torch.fx.symbolic_trace(einmatvecmul)
# 2. Define example inputs for shape propagation and optimization
# These shapes are used to determine the optimal contraction path.
example_inputs = (
    torch.randn(7, 4, 5),  # a: (z, i, j)
    torch.randn(7, 5, 3),  # b: (z, j, k)
    torch.randn(7, 3),     # vec: (z, k)
)
# 3. Optimize the einsums within the FX graph
graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=example_inputs,
)
# 4. (Optional) Print the optimized code to see the changes
print("Original code:\n", graph_mod.code)
print("Optimized code:\n", graph_opt.code)
# 5. Run the optimized graph and verify correctness
output_original = graph_mod(*example_inputs)
output_optimized = graph_opt(*example_inputs)
assert torch.allclose(output_original, output_optimized)
print("\nOptimization successful and outputs match!")