ONNX Intermediate Representation (IR)
onnx-ir provides an efficient in-memory representation of ONNX graphs for creating, manipulating, and optimizing ONNX models programmatically in Python. The current version is 0.2.0; the project is under active development and often ships several patch releases per month.
Warnings
- breaking In `v0.1.9`, the `ir.Input` class was deprecated in favor of `ir.val`. Additionally, attribute signatures for nodes were simplified, requiring tuples for repeating attributes. Code using the old `ir.Input` or the previous attribute signature format will break.
- gotcha When calling `Value.replace_all_uses_with` (introduced in `v0.1.12`) with the `replace_graph_outputs` option, the replacement value takes over the graph output slot but keeps its own name. To preserve the graph's signature, you must assign the original output's name to the replacement value yourself; otherwise the graph's output names change.
- gotcha The `v0.2.0` release added `sympy` as a dependency for symbolic shape inference. This increases the total install size and may slightly lengthen installation in new environments.
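The `replace_graph_outputs` gotcha can be modeled in a few lines of plain Python. This is a minimal sketch, not onnx-ir's API: `ToyValue`, `ToyNode`, `ToyGraph`, and the free function `replace_all_uses_with` are hypothetical stand-ins written for illustration, and the real `Value.replace_all_uses_with` method may differ internally. It shows why the graph's output name changes after the swap and how copying the old name onto the replacement restores the signature.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class ToyNode:
    """Hypothetical node: an op type plus the values it reads."""
    op_type: str
    inputs: list


@dataclass(eq=False)
class ToyValue:
    """Hypothetical value: a name plus the nodes that consume it."""
    name: str
    uses: list = field(default_factory=list)


@dataclass
class ToyGraph:
    """Hypothetical graph: only the ordered outputs matter for this sketch."""
    outputs: list

    def signature(self):
        return [v.name for v in self.outputs]


def replace_all_uses_with(old, new, graph, replace_graph_outputs=False):
    """Toy model: rewire every consumer of `old` to read `new` instead."""
    for node in old.uses:
        node.inputs = [new if v is old else v for v in node.inputs]
        new.uses.append(node)
    old.uses = []
    if replace_graph_outputs:
        # The output slot now holds `new`, which still carries its own name.
        graph.outputs = [new if v is old else v for v in graph.outputs]


# output_c is read by a downstream node and is also a graph output.
output_c = ToyValue("output_c")
relu = ToyNode("Relu", [output_c])
output_c.uses.append(relu)
graph = ToyGraph([output_c])

replacement = ToyValue("fused_out")
replace_all_uses_with(output_c, replacement, graph, replace_graph_outputs=True)
print(graph.signature())  # ['fused_out']: the public output name changed

# Preserve the original signature by handing the old name to the replacement.
replacement.name = output_c.name
print(graph.signature())  # ['output_c']
```

The final rename is the manual step the warning refers to; in this toy model nothing enforces unique names, whereas a real graph may have its own naming rules.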
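To give a sense of what symbolic shape inference involves, here is a small sketch that uses `sympy` directly. The helpers `concat_shape` and `reshape_shape` are hypothetical; they only assume ONNX's `Concat` semantics and `Reshape`'s use of `-1` for an inferred dimension (ignoring its special meaning of `0`), and they are not how onnx-ir implements inference. The point is that a dimension such as a batch size `N` can be carried through shape arithmetic and simplified.

```python
import math

import sympy

# Symbolic dimensions, e.g. batch sizes that are only known at runtime.
N, M = sympy.symbols("N M", positive=True, integer=True)


def concat_shape(a, b, axis):
    """Shape of concatenating tensors shaped `a` and `b` along `axis` (hypothetical helper)."""
    out = list(a)
    out[axis] = a[axis] + b[axis]  # sympy keeps the sum symbolic
    return out


def reshape_shape(shape, target):
    """Shape after an ONNX-style Reshape where -1 marks the inferred dimension.

    Hypothetical helper; it ignores Reshape's special meaning of 0.
    """
    # Start from sympy.Integer(1) so pure-integer shapes divide exactly.
    total = math.prod(shape, start=sympy.Integer(1))
    known = math.prod((d for d in target if d != -1), start=sympy.Integer(1))
    return [sympy.simplify(total / known) if d == -1 else d for d in target]


# Concatenating [N, 4] and [M, 4] along axis 0: the first dimension becomes N + M.
print(concat_shape([N, 4], [M, 4], axis=0))

# Reshaping [N, 3, 4] to [-1, 3]: the inferred dimension simplifies to 4*N.
print(reshape_shape([N, 3, 4], [-1, 3]))
```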
Install
pip install onnx-ir
Imports
- Graph
from onnx_ir import Graph
- Node
from onnx_ir import Node
- Value
from onnx_ir import Value
- TensorType
from onnx_ir import TensorType
- Shape
from onnx_ir import Shape
- DataType
from onnx_ir import DataType
Quickstart
import onnx_ir as ir
# Create input values with specific types and shapes (Value takes keyword arguments)
input_a = ir.Value(name="input_a", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT))
input_b = ir.Value(name="input_b", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT))
# Define the output value for the Add node
output_c = ir.Value(name="output_c")
# Create an 'Add' node in the default ONNX domain ("") with inputs and outputs
add_node = ir.Node("", "Add", inputs=[input_a, input_b], outputs=[output_c])
# Assemble a graph from inputs, outputs, and nodes
graph = ir.Graph(
    [input_a, input_b],  # Graph inputs
    [output_c],  # Graph outputs
    nodes=[add_node],  # Nodes in the graph
    name="simple_add_graph",  # Name of the graph
)
print(f"Created graph: {graph.name}")
# A Graph is a sequence of its nodes, so len() counts them
print(f"Graph has {len(graph)} node(s).")
print(f"Input 'input_a' shape: {input_a.shape.dims}")