ONNX GraphSurgeon

0.6.1 · active · verified Mon Apr 13

ONNX GraphSurgeon is a Python library for programmatically manipulating ONNX graphs. It provides a high-level API to import, inspect, modify, and export ONNX models, and is often used together with NVIDIA TensorRT in optimization and deployment workflows. The current PyPI version is 0.6.1, but development follows TensorRT's release cadence, and newer versions are often bundled with TensorRT distributions.

Warnings

Install

Imports

Quickstart

This quickstart builds a small dummy ONNX model, loads it with GraphSurgeon, finds its Identity node, replaces that node with a new 'Add' node that takes a new constant input, and saves the modified model. It also runs `onnx.checker.check_model` on the result for basic graph validation.

import onnx
import onnx_graphsurgeon as gs
from onnx import helper, TensorProto
import numpy as np
import os

# 1. Create a dummy ONNX model for demonstration
def create_dummy_model():
    """Build a tiny opset-13 ONNX model: X -> Add(const_add_val_1) -> Identity -> Y.

    ``X`` and ``Y`` are declared as FLOAT tensors of shape [1, 3, 16, 16].
    ``const_add_val_1`` is a scalar FLOAT initializer holding 1.0.

    Returns:
        The ``ModelProto`` produced by ``helper.make_model``.
    """
    dims = [1, 3, 16, 16]
    graph_input = helper.make_tensor_value_info('X', TensorProto.FLOAT, dims)
    graph_output = helper.make_tensor_value_info('Y', TensorProto.FLOAT, dims)

    # Scalar (rank-0) constant; broadcasts against X inside the Add.
    add_constant = helper.make_tensor('const_add_val_1', TensorProto.FLOAT, [], [1.0])

    nodes = [
        helper.make_node('Add', ['X', 'const_add_val_1'], ['Add_Output']),
        helper.make_node('Identity', ['Add_Output'], ['Y']),
    ]

    graph_def = helper.make_graph(
        nodes,
        'simple_graph',
        inputs=[graph_input],
        outputs=[graph_output],
        initializer=[add_constant],
    )
    return helper.make_model(
        graph_def,
        producer_name='dummy-model',
        opset_imports=[helper.make_opsetid("", 13)],
    )

# Write the dummy model to disk so the rest of the example starts from a file,
# the same way it would with a real exported model.
dummy_model = create_dummy_model()
onnx.save(dummy_model, "dummy_model.onnx")
print("Original model saved to dummy_model.onnx")

# 2. Read the file back and convert it into a GraphSurgeon graph.
graph = gs.import_onnx(
    onnx.load("dummy_model.onnx")
)

# 3. Modify the graph: Replace the Identity node with a second Add node.
# Pick the first Identity node in graph order, or None if the graph has none.
identity_node = next((n for n in graph.nodes if n.op == "Identity"), None)

if identity_node is not None:
    input_tensor = identity_node.inputs[0]   # Output of the first Add node
    output_tensor = identity_node.outputs[0]  # The graph's final output tensor 'Y'

    # Detach the Identity node from its tensors before dropping it. Removing it from
    # graph.nodes alone would leave it registered as a producer of 'Y' and a consumer
    # of 'Add_Output'. 'Y' would then have two producers once the new Add node is
    # attached below.
    identity_node.inputs.clear()
    identity_node.outputs.clear()
    graph.nodes.remove(identity_node)

    # Constant second operand for the replacement Add.
    const_val_tensor_2 = gs.Constant(name="const_add_val_2", values=np.array([2.0], dtype=np.float32))

    # New Add node: reads the first Add's output and writes the original output tensor 'Y'.
    new_add_node = gs.Node(
        op="Add",
        inputs=[input_tensor, const_val_tensor_2],
        outputs=[output_tensor],
    )
    graph.nodes.append(new_add_node)
else:
    # Make the no-op visible instead of silently continuing.
    print("No Identity node found; graph left unchanged")

# After any graph edit, clean up the graph and topologically sort it before export.
# cleanup() returns the graph, which is why toposort() can be chained onto it.
graph.cleanup().toposort()

# 4. Convert the GraphSurgeon graph back to an ONNX model and write it out.
modified_model = gs.export_onnx(graph)
onnx.save(modified_model, "modified_dummy_model.onnx")
print(
    "Modified model saved to modified_dummy_model.onnx (Identity replaced with Add)"
)

# Optional: Verify the modified model with the ONNX checker.
# Only checker validation failures are reported here. Unexpected errors (e.g. passing
# something that is not a ModelProto) still propagate instead of being hidden.
try:
    onnx.checker.check_model(modified_model)
except onnx.checker.ValidationError as e:
    print(f"Modified model check failed: {e}")
else:
    print("Modified model check successful!")

# Delete the two model files this example wrote, in the order they were created.
for generated_file in ("dummy_model.onnx", "modified_dummy_model.onnx"):
    os.remove(generated_file)

view raw JSON →