ONNX to TensorFlow/TFLite Converter

2.4.0 · active · verified Sun Apr 12

onnx2tf is a Python tool for converting ONNX model files into several target formats: LiteRT, TFLite, TensorFlow SavedModel, PyTorch native code (nn.Module), TorchScript (.pt), state_dict (.pt), Exported Program (.pt2), and Dynamo ONNX. It also supports direct conversion from LiteRT to PyTorch. The project releases frequently; version 2.4.0 is the latest stable release as of this entry.

Quickstart

This quickstart walks through the full workflow: define a small PyTorch model, export it to ONNX, and convert the ONNX file to a TensorFlow SavedModel with `onnx2tf.convert()`, the library's primary entry point, which needs only the input ONNX path and an output folder path.

import onnx2tf
import torch
import torch.nn as nn
import os

# 1. Define a simple PyTorch model
class SimpleModel(nn.Module):
    def __init__(self):
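        super().__init__()
        # 3 input channels -> 16 output channels, 3x3 kernel; padding=1 keeps the 224x224 size
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)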

    def forward(self, x):
        return self.conv(x)

# 2. Instantiate and export to ONNX
model = SimpleModel().eval()  # switch to inference mode before exporting
dummy_input = torch.randn(1, 3, 224, 224)
onnx_file_path = "simple_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    opset_version=17,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
print(f"PyTorch model exported to {onnx_file_path}")

# 3. Convert ONNX to TensorFlow SavedModel
output_folder = "./converted_tf_model"
os.makedirs(output_folder, exist_ok=True)

onnx2tf.convert(
    input_onnx_file_path=onnx_file_path,
    output_folder_path=output_folder,
    # For TFLite conversion, you might add:
    # tflite_output_file_path="./converted_tf_model/model.tflite"
)
print(f"ONNX model converted to TensorFlow SavedModel at {output_folder}")

# To use a specific backend for TFLite (e.g., the deprecated tf_converter),
# run this before the ONNX file is deleted below:
# onnx2tf.convert(
#     input_onnx_file_path=onnx_file_path,
#     output_folder_path="./converted_tf_model_tfconv",
#     tflite_output_file_path="./converted_tf_model_tfconv/model.tflite",
#     tflite_backend='tf_converter'
# )

# Clean up the intermediate ONNX file
os.remove(onnx_file_path)
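Comparing the converted model against the PyTorch original usually means accounting for tensor layout: by default onnx2tf transposes 4D tensors from ONNX's NCHW order to TensorFlow's NHWC order, so the quickstart's (1, 3, 224, 224) input typically becomes (1, 224, 224, 3). The sketch below assumes that default behavior and uses only NumPy to show both transpositions and a tolerance-based output check. The helper names `nchw_to_nhwc`, `nhwc_to_nchw`, and `outputs_match` are hypothetical and are not part of onnx2tf.

```python
import numpy as np

# Assumption: onnx2tf's default conversion turns NCHW tensors into NHWC.
# These helpers are illustrative only; they are not part of the onnx2tf API.


def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    """Move the channel axis of a 4D batch from position 1 to the end."""
    if x.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {x.shape}")
    return np.transpose(x, (0, 2, 3, 1))


def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    """Inverse of nchw_to_nhwc: move the channel axis back to position 1."""
    if x.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {x.shape}")
    return np.transpose(x, (0, 3, 1, 2))


def outputs_match(torch_out_nchw: np.ndarray, tf_out_nhwc: np.ndarray,
                  atol: float = 1e-4) -> bool:
    """Compare an NCHW reference output with an NHWC output from the converted model.

    The tolerance is a placeholder; choose one that suits the model and precision.
    """
    return bool(np.allclose(torch_out_nchw, nhwc_to_nchw(tf_out_nhwc), atol=atol))


# Same shape as the quickstart's dummy_input
x_nchw = np.random.rand(1, 3, 224, 224).astype(np.float32)
x_nhwc = nchw_to_nhwc(x_nchw)
print(x_nhwc.shape)  # (1, 224, 224, 3)
```

To run the comparison, feed `x_nhwc` to the converted model (for example through `tf.lite.Interpreter`) and pass its result, together with the PyTorch output converted to a NumPy array, to `outputs_match`. Whether a particular input or output is transposed can depend on the model and on conversion options, so check the converted model's tensor shapes before relying on this.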
