ONNX Script
ONNX Script is a Python library that enables developers to naturally author ONNX functions and models using a subset of Python. It provides tools to translate Python functions into serialized ONNX graphs, offering an expressive, simple, and debuggable way to define ONNX models. The library is actively maintained with frequent patch releases addressing bug fixes and minor improvements.
Warnings
- breaking In v0.6.0, the `.param_schemas` and `.schema` properties of `ONNXFunction` were removed. Use the more flexible `.op_signature` property instead.
- breaking In v0.5.5, the constant folding pass was changed to emit initializers instead of `Constant` nodes. Downstream tools or code that expect folded constants to appear as `Constant` nodes in the ONNX graph may need to be updated.
- gotcha ONNX Script supports only a *subset* of Python. Constructs with no ONNX equivalent (for example, arbitrary Python data structures or control flow that ONNX cannot express) cannot be translated into a valid ONNX graph and cause errors when the function is scripted.
- gotcha Eager-mode evaluation of ONNX Script functions is intended for debugging and for understanding a function's behavior from Python. It is not optimized for performance; for performance-sensitive inference, export the model and run it with an ONNX runtime instead.
- gotcha Explicit type annotations for inputs, outputs, and attributes are crucial when defining functions with `@script()`. Tensor parameters are annotated with ONNX Script tensor types such as `FLOAT[...]`, while attributes use plain Python types such as `int` or `float`. Missing or incorrect annotations (for tensor types, shapes, or attribute types) can lead to conversion errors or an incorrect ONNX graph.
Install
pip install onnxscript
Imports
- script
from onnxscript import script
- opsetXX
from onnxscript import opset15 as op
- FLOAT
from onnxscript.onnx_types import FLOAT
Quickstart
import onnx
from onnxscript import script, FLOAT
from onnxscript import opset15 as op
import numpy as np
# Define an ONNX function using the @script decorator
@script()
def MatmulAdd(X: FLOAT['N', 'K'], Wt: FLOAT['K', 'M'], Bias: FLOAT['M',]) -> FLOAT['N', 'M']:
    return op.MatMul(X, Wt) + Bias
# Create some dummy input data
x_data = np.random.rand(64, 128).astype(np.float32)
wt_data = np.random.rand(128, 10).astype(np.float32)
bias_data = np.random.rand(10,).astype(np.float32)
# Evaluate the ONNX Script function in eager mode (for debugging/testing)
result_eager = MatmulAdd(x_data, wt_data, bias_data)
print(f"Eager mode output shape: {result_eager.shape}")
# Convert the ONNX Script function to an ONNX ModelProto
# No example inputs are needed: input and output types come from the annotations
model_proto = MatmulAdd.to_model_proto()
# Save the ONNX model
onnx_file_path = "matmul_add_model.onnx"
onnx.save(model_proto, onnx_file_path)
print(f"ONNX model saved to {onnx_file_path}")
# Optionally, check the model for validity
try:
    onnx.checker.check_model(model_proto)
    print("ONNX model is valid!")
except onnx.checker.ValidationError as e:
    print(f"ONNX model validation error: {e}")