ONNX to PyTorch Converter

1.6.0 · active · verified Thu Apr 16

onnx2torch is a Python library that converts ONNX (Open Neural Network Exchange) models into PyTorch models. It lets users bring existing ONNX models into the PyTorch ecosystem, easing migration and interoperability between frameworks. The current stable version is 1.6.0, and releases typically land every few months.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates the full cycle of using `onnx2torch`. It starts by creating a dummy PyTorch model, exporting it to an ONNX file, converting that ONNX file back into a PyTorch model using `onnx2torch.convert`, and finally running inference with the newly converted model to show its usage.

import os

import torch
from onnx2torch import convert

# Define paths for dummy model
onnx_model_path = 'dummy_model.onnx'

try:
    # 1. Create a dummy PyTorch model and export it to ONNX
    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 2)

        def forward(self, x):
            return self.linear(x)

    dummy_torch_model = SimpleModel()
    dummy_input = torch.randn(1, 10) # Batch size 1, 10 features

    torch.onnx.export(dummy_torch_model, dummy_input, onnx_model_path,
                      opset_version=11, input_names=['input'], output_names=['output'],
                      do_constant_folding=True,  # Recommended for stable ONNX
                      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    print(f"Dummy ONNX model saved to {onnx_model_path}")

    # 2. Convert the ONNX model to a PyTorch model using onnx2torch
    torch_model_converted = convert(onnx_model_path)
    print("ONNX model converted to PyTorch successfully.")

    # 3. Perform inference with the converted model
    inference_input = torch.randn(2, 10) # Example: batch size 2
    torch_model_converted.eval() # Set to evaluation mode
    with torch.no_grad():
        output = torch_model_converted(inference_input)
    
    print(f"Input shape for inference: {inference_input.shape}")
    print(f"Output shape from converted model: {output.shape}")

    # (Optional) Verify output matches original model if possible
    # original_output = dummy_torch_model(inference_input)
    # print(f"Original model output shape: {original_output.shape}")
    # assert torch.allclose(output, original_output, atol=1e-5), "Outputs do not match!"
    # print("Outputs of original and converted model match (within tolerance).")

finally:
    # Clean up the dummy ONNX file
    if os.path.exists(onnx_model_path):
        os.remove(onnx_model_path)
        print(f"Cleaned up {onnx_model_path}")
