ONNX Simplifier
ONNX Simplifier (onnxsim) is a Python library that simplifies ONNX models: it infers the whole computation graph and replaces redundant operators with their constant outputs (constant folding). The resulting graph is usually smaller and easier to run and deploy. The project is actively maintained with frequent minor releases; the current version is 0.6.2.
Warnings
- gotcha Building `onnxsim` from source (e.g., if pre-built wheels are unavailable for your specific platform/Python version) requires `cmake` and a C++ compiler. Installation may fail without these system-level dependencies.
- gotcha Models larger than the protobuf message size limit (2 GB) can fail to load or simplify, because the protobuf library underlying ONNX cannot handle a single message above that size.
- gotcha ONNX models with graphs that are not topologically sorted may result in validation errors during simplification.
- deprecated The `--enable-fuse-bn` command-line argument, which fused batch normalization into convolutional layers, is deprecated because this optimization is now enabled by default. Use `--skip-fuse-bn` to turn it off.
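- gotcha When simplifying models with dynamic input shapes, you often need to provide the expected input shapes explicitly, using the `--input-shape` argument in the CLI or the `input_shapes` parameter of the `simplify` function. Recent releases name these `--overwrite-input-shape` and `overwrite_input_shapes` and keep the older spellings only as deprecated aliases.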
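The dynamic-shape gotcha above can be sketched as follows. This is a minimal illustration, not the library's own parser: `parse_input_shapes` and `simplify_with_fixed_shapes` are hypothetical helper names, the `name:dim0,dim1,...` spec format mirrors the CLI convention, and the `overwrite_input_shapes` keyword is assumed to be available (older releases take `input_shapes` instead).

```python
from typing import Dict, List, Optional


def parse_input_shapes(specs: List[str]) -> Dict[Optional[str], List[int]]:
    """Turn CLI-style specs such as "input:1,3,224,224" into a shape mapping.

    A spec without a name ("1,3,224,224") is stored under the key None,
    which is assumed here to mean "the model's only input".
    """
    shapes: Dict[Optional[str], List[int]] = {}
    for spec in specs:
        # Split on the last colon so names like "input:0" survive intact.
        name, sep, dims = spec.rpartition(":")
        shapes[name if sep else None] = [int(d) for d in dims.split(",")]
    return shapes


def simplify_with_fixed_shapes(src: str, dst: str, specs: List[str]) -> None:
    """Hypothetical helper: simplify `src` using fixed input shapes."""
    # Imported lazily so the parser above works without onnx installed.
    import onnx
    from onnxsim import simplify

    model = onnx.load(src)
    # Assumption: this release accepts `overwrite_input_shapes`;
    # older releases use `input_shapes` for the same purpose.
    model_simp, ok = simplify(
        model, overwrite_input_shapes=parse_input_shapes(specs)
    )
    if not ok:
        raise RuntimeError("Simplified ONNX model could not be validated")
    onnx.save(model_simp, dst)
```

For a model whose only input is named `input`, a call such as `simplify_with_fixed_shapes("model.onnx", "model_simplified.onnx", ["input:1,10"])` would pin the batch dimension to 1 before simplification; the file names here are placeholders.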
Install
pip install onnxsim
Imports
- simplify
from onnxsim import simplify
Quickstart
import onnx
from onnxsim import simplify
import torch
import torch.nn as nn
import os
# Create a dummy ONNX model for demonstration
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)
model = SimpleNet()
dummy_input = torch.randn(1, 10)
input_model_path = "dummy_model.onnx"
output_model_path = "dummy_model_simplified.onnx"
torch.onnx.export(model, dummy_input, input_model_path,
                  input_names=['input'], output_names=['output'])
# Load your predefined ONNX model
onnx_model = onnx.load(input_model_path)
# Simplify the model; `check` reports whether the result passed validation
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
# Save the simplified model
onnx.save(model_simp, output_model_path)
print(f"Model simplified and saved to {output_model_path}")
# Clean up dummy models
os.remove(input_model_path)
os.remove(output_model_path)
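The topological-ordering gotcha listed under Warnings can be checked before calling `simplify`. The sketch below is one possible pre-flight check, not part of onnxsim: `find_order_violations` and `check_onnx_file` are hypothetical names, the check works on plain (inputs, outputs) pairs so it can be tried without any ONNX tooling, and it only looks at the top-level graph (subgraphs inside control-flow nodes are ignored).

```python
from typing import Iterable, List, Sequence, Set, Tuple

# A node is reduced to (input tensor names, output tensor names).
Node = Tuple[Sequence[str], Sequence[str]]


def find_order_violations(
    nodes: Iterable[Node], available: Iterable[str]
) -> List[Tuple[int, str]]:
    """Return (node_index, tensor_name) pairs where a node reads a tensor
    that no graph input, initializer, or earlier node has produced."""
    seen: Set[str] = set(available)
    violations: List[Tuple[int, str]] = []
    for index, (inputs, outputs) in enumerate(nodes):
        for name in inputs:
            # An empty name marks an omitted optional input in ONNX.
            if name and name not in seen:
                violations.append((index, name))
        seen.update(outputs)
    return violations


def check_onnx_file(path: str) -> List[Tuple[int, str]]:
    """Hypothetical wrapper that applies the check to an ONNX file."""
    import onnx  # imported lazily; the check itself needs only the stdlib

    graph = onnx.load(path).graph
    available = [i.name for i in graph.input]
    available += [t.name for t in graph.initializer]
    nodes = [(list(n.input), list(n.output)) for n in graph.node]
    return find_order_violations(nodes, available)
```

An empty result means every node's inputs are produced before the node runs. A non-empty result names the first tensors read out of order, which is usually enough to locate the export step that emitted the nodes in the wrong sequence.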