ONNX GraphSurgeon
ONNX GraphSurgeon is a Python library for programmatically manipulating ONNX graphs. It provides a high-level API to import, inspect, modify, and export ONNX models, and is often used alongside NVIDIA TensorRT in optimization and deployment workflows. The current PyPI version is 0.6.1. Development is closely tied to TensorRT's release cadence, so newer versions are often bundled with TensorRT distributions.
Warnings
- breaking The PyPI package `onnx-graphsurgeon` (v0.6.1) might be significantly older than the version integrated with current NVIDIA TensorRT releases (e.g., TensorRT 10.x). For optimal compatibility, especially with new ONNX operators or TensorRT features, it is often recommended to use the `onnx-graphsurgeon` version bundled with your specific TensorRT installation or build from source.
- gotcha ONNX GraphSurgeon does not automatically validate the semantic correctness of graph modifications. Incorrect changes can lead to invalid ONNX graphs that fail `onnx.checker.check_model()` or cannot be parsed/optimized by runtimes like TensorRT.
- gotcha Careful management of `gs.Tensor` objects (inputs/outputs) is crucial during graph manipulation. Incorrectly linking tensors, or failing to update consumer/producer relationships when a node is removed, can result in a broken graph. Call `graph.cleanup().toposort()` after modifications to drop nodes and tensors that no longer contribute to the graph outputs and to restore topological order.
Install
pip install onnx-graphsurgeon==0.6.1
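Because the PyPI release and the copy bundled with TensorRT can differ, it may help to confirm which version is actually active in the current environment. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical, and `tensorrt` is assumed to be the distribution name of the TensorRT Python bindings in your environment.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names are assumptions; adjust them to match your environment.
for dist in ("onnx-graphsurgeon", "onnx", "tensorrt"):
    print(f"{dist}: {installed_version(dist) or 'not installed'}")
```

If the reported `onnx-graphsurgeon` version is older than the one shipped with your TensorRT release, installing the bundled version or building from source is the safer choice.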
Imports
- gs
import onnx_graphsurgeon as gs
- Graph
from onnx_graphsurgeon.ir.graph import Graph
- Node
from onnx_graphsurgeon.ir.node import Node
- Tensor
from onnx_graphsurgeon.ir.tensor import Tensor
Quickstart
import onnx
import onnx_graphsurgeon as gs
from onnx import helper, TensorProto
import numpy as np
import os
# 1. Create a dummy ONNX model for demonstration
def create_dummy_model():
    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3, 16, 16])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 3, 16, 16])
    const_val_tensor_1 = helper.make_tensor('const_add_val_1', TensorProto.FLOAT, [], [1.0])
    # Model: X -> Add (with const_add_val_1) -> Add_Output -> Identity -> Y
    node_add = helper.make_node('Add', ['X', 'const_add_val_1'], ['Add_Output'])
    node_identity = helper.make_node('Identity', ['Add_Output'], ['Y'])
    graph_def = helper.make_graph(
        [node_add, node_identity],
        'simple_graph',
        [X],
        [Y],
        [const_val_tensor_1]
    )
    model = helper.make_model(graph_def, producer_name='dummy-model', opset_imports=[helper.make_opsetid("", 13)])
    return model
# Save the dummy model to a file
dummy_model = create_dummy_model()
onnx.save(dummy_model, "dummy_model.onnx")
print("Original model saved to dummy_model.onnx")
# 2. Load the model with ONNX GraphSurgeon
graph = gs.import_onnx(onnx.load("dummy_model.onnx"))
# 3. Modify the graph: Replace the Identity node with a second Add node
identity_node = None
for node in graph.nodes:
    if node.op == "Identity":
        identity_node = node
        break
if identity_node is not None:
    input_tensor = identity_node.inputs[0]  # Output of the first Add node
    output_tensor = identity_node.outputs[0]  # The graph's final output tensor 'Y'
    # Disconnect the old Identity node from its tensors so that 'Y' keeps a
    # single producer, then remove it from the graph
    identity_node.inputs.clear()
    identity_node.outputs.clear()
    graph.nodes.remove(identity_node)
    # Create a new Constant for the second Add op
    const_val_tensor_2 = gs.Constant(name="const_add_val_2", values=np.array([2.0], dtype=np.float32))
    # Create a new Add node to replace Identity
    new_add_node = gs.Node(
        op="Add",
        inputs=[input_tensor, const_val_tensor_2],  # Input from first Add, plus new constant
        outputs=[output_tensor]  # Re-use the original output tensor 'Y'
    )
    # Add the new node to the graph
    graph.nodes.append(new_add_node)
# Always clean up and topologically sort after graph modifications
graph.cleanup().toposort()
# 4. Save the modified model
modified_model = gs.export_onnx(graph)
onnx.save(modified_model, "modified_dummy_model.onnx")
print("Modified model saved to modified_dummy_model.onnx (Identity replaced with Add)")
# Optional: Verify the modified model with ONNX checker
try:
    onnx.checker.check_model(modified_model)
    print("Modified model check successful!")
except Exception as e:
    print(f"Modified model check failed: {e}")
# Cleanup created files
os.remove("dummy_model.onnx")
os.remove("modified_dummy_model.onnx")