ONNX to PyTorch Converter
onnx2torch is a Python library for converting ONNX (Open Neural Network Exchange) models into PyTorch models. It lets you reuse existing ONNX models within the PyTorch ecosystem, easing migration and interoperability between frameworks. The current stable version is 1.6.0, and releases typically land every few months.
Common errors
- `onnx2torch.exception.UnsupportedOperatorException: Unsupported ONNX operator: SomeOperatorName`
  - Cause: The ONNX model contains an operator that `onnx2torch` does not currently support or recognize.
  - Fix: Check the `onnx2torch` documentation or GitHub repository for the list of supported ONNX operators. You may need to preprocess your ONNX model to remove or replace unsupported operators, or simplify the model graph.
- `RuntimeError: shape '[-1, 3, 224, 224]' is invalid for input of size 150527`
  - Cause: The input tensor passed to the converted PyTorch model does not match the expected shape. This often happens with dynamic axes or incorrect input dimensions at inference time.
  - Fix: Verify the expected input shape of your converted model. If the ONNX model used dynamic axes, make sure your inference inputs respect the model's flexible dimensions, and double-check batch size, channels, and spatial dimensions.
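The error boils down to arithmetic: a reshape with a `-1` only succeeds when the tensor's element count is divisible by the product of the remaining dimensions. A framework-free sketch of that check (the function name and shapes are illustrative, not part of any library):

```python
from math import prod

def infer_batch_dim(numel: int, fixed_dims: tuple) -> int:
    """Return the batch size a -1 would resolve to, or raise as PyTorch would."""
    block = prod(fixed_dims)
    if numel % block != 0:
        raise ValueError(
            f"shape [-1, {', '.join(map(str, fixed_dims))}] is invalid "
            f"for input of size {numel}")
    return numel // block

# A batch of 2 RGB 224x224 images reshapes cleanly...
print(infer_batch_dim(2 * 3 * 224 * 224, (3, 224, 224)))  # 2
# ...but an off-by-one element count does not.
try:
    infer_batch_dim(2 * 3 * 224 * 224 + 1, (3, 224, 224))
except ValueError as e:
    print(e)
```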
- `AttributeError: module 'onnx2torch' has no attribute 'convert'`
  - Cause: This typically occurs when an `onnx2torch` version older than 1.0.0 is used while importing `convert` from the top-level package, or when there is a typo in the import statement.
  - Fix: For `onnx2torch` versions older than 1.0.0, use `from onnx2torch.converter import convert`. For 1.0.0 and newer, make sure your installation is up to date and import with `from onnx2torch import convert`.
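If your code must run against both old and new onnx2torch versions, one defensive pattern (a sketch, not part of the onnx2torch API) is to try the modern top-level import first and fall back to the pre-1.0.0 path:

```python
# Try the >= 1.0.0 top-level import first, then fall back to the
# pre-1.0.0 location; leave convert as None if the package is absent.
try:
    from onnx2torch import convert
except ImportError:
    try:
        from onnx2torch.converter import convert
    except ImportError:
        convert = None  # onnx2torch is not installed

if convert is None:
    print("onnx2torch is not installed")
else:
    print("convert is available from:", convert.__module__)
```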
Warnings
- breaking: The primary conversion function's signature and import path changed significantly in version 1.0.0. `convert` is now available directly from the top-level `onnx2torch` package, and its arguments may differ from older versions.
- gotcha: Not all ONNX operators are implemented by `onnx2torch`; converting a model that uses an unsupported operator raises `UnsupportedOperatorException`.
- gotcha: Models with dynamic input shapes are not always handled correctly by `onnx2torch`, which can lead to conversion failures or incorrect runtime behavior in the converted PyTorch model.
- gotcha: The ONNX `opset_version` used during model export matters: a mismatch with `onnx2torch`'s internal operator definitions can cause conversion errors.
Install
```shell
pip install onnx2torch
```
Imports
- convert
  - `from onnx2torch import convert` (version 1.0.0 and newer)
  - `from onnx2torch.converter import convert` (versions older than 1.0.0)
Quickstart
```python
import os

import torch
from onnx2torch import convert

# Define path for the dummy ONNX model
onnx_model_path = 'dummy_model.onnx'

try:
    # 1. Create a dummy PyTorch model and export it to ONNX
    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 2)

        def forward(self, x):
            return self.linear(x)

    dummy_torch_model = SimpleModel()
    dummy_input = torch.randn(1, 10)  # Batch size 1, 10 features
    torch.onnx.export(dummy_torch_model, dummy_input, onnx_model_path,
                      opset_version=11, input_names=['input'], output_names=['output'],
                      do_constant_folding=True,  # Recommended for a stable ONNX graph
                      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    print(f"Dummy ONNX model saved to {onnx_model_path}")

    # 2. Convert the ONNX model to a PyTorch model using onnx2torch
    torch_model_converted = convert(onnx_model_path)
    print("ONNX model converted to PyTorch successfully.")

    # 3. Perform inference with the converted model
    inference_input = torch.randn(2, 10)  # Example: batch size 2
    torch_model_converted.eval()  # Set to evaluation mode
    with torch.no_grad():
        output = torch_model_converted(inference_input)
    print(f"Input shape for inference: {inference_input.shape}")
    print(f"Output shape from converted model: {output.shape}")

    # (Optional) Verify the converted model's output against the original
    # original_output = dummy_torch_model(inference_input)
    # assert torch.allclose(output, original_output, atol=1e-5), "Outputs do not match!"
finally:
    # Clean up the dummy ONNX file
    if os.path.exists(onnx_model_path):
        os.remove(onnx_model_path)
        print(f"Cleaned up {onnx_model_path}")
```