TensorFlow to ONNX Converter

1.17.0 · active · verified Wed Apr 15

tf2onnx is a Python library that converts TensorFlow models to the ONNX (Open Neural Network Exchange) format, so that models trained in TensorFlow can be deployed to ONNX-compatible runtimes and hardware accelerators. The current version is 1.17.0, with minor releases typically occurring every 1-3 months to keep pace with TensorFlow and ONNX updates.

Warnings

Install

Imports

Quickstart

This quickstart converts a simple Keras sequential model to the ONNX format using `tf2onnx.convert.from_keras`. It builds the model, saves it as a TensorFlow SavedModel, performs the conversion, writes the ONNX output to disk, and optionally verifies the ONNX model's validity. The Python API converts in-memory models (`from_keras`, `from_function`, `from_graph_def`, `from_tflite`); a SavedModel directory is converted with the command-line interface instead (`python -m tf2onnx.convert --saved-model my_keras_model --output model.onnx`).

import tensorflow as tf
import tf2onnx
from onnx.checker import check_model
import shutil
import os

# 1. Create a simple Keras model
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10, input_shape=(784,), activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 2. Save the Keras model in the TensorFlow SavedModel format
# The command-line converter accepts this directory via --saved-model.
tf.saved_model.save(model, "my_keras_model")

# 3. Convert the in-memory Keras model to ONNX
# input_signature fixes the input name, shape, and dtype; None keeps the batch dimension dynamic.
# from_keras returns the ONNX ModelProto and external tensor storage (unused here).
onnx_model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=[tf.TensorSpec([None, 784], tf.float32, name="input_0")]
)

# 4. Save the ONNX model to a file
with open("model.onnx", "wb") as f:
    f.write(onnx_model_proto.SerializeToString())

print("ONNX model converted and saved as model.onnx")

# Optional: Verify the ONNX model structure
try:
    check_model(onnx_model_proto)
    print("ONNX model is valid.")
except Exception as e:
    print(f"ONNX model validation failed: {e}")

# Clean up created files/directories
shutil.rmtree("my_keras_model")
os.remove("model.onnx")
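
The SavedModel directory written in step 2 can also be converted from the command line, as long as the conversion runs before the cleanup step removes it. The following is a minimal sketch of driving that converter from Python. It assumes tf2onnx is installed in the current interpreter's environment; `build_convert_command` and `convert_saved_model` are hypothetical helpers written for this illustration, not part of tf2onnx. Only the `--saved-model`, `--output`, and `--opset` flags of `python -m tf2onnx.convert` are used.

```python
import subprocess
import sys
from typing import List, Optional


def build_convert_command(saved_model_dir: str, output_path: str,
                          opset: Optional[int] = None) -> List[str]:
    """Build the argv list for tf2onnx's command-line converter.

    Hypothetical helper for illustration; not part of tf2onnx.
    """
    cmd = [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", saved_model_dir,
        "--output", output_path,
    ]
    # Only pass --opset when the caller asks for a specific opset version.
    if opset is not None:
        cmd += ["--opset", str(opset)]
    return cmd


def convert_saved_model(saved_model_dir: str, output_path: str,
                        opset: Optional[int] = None) -> None:
    """Run the converter; raises subprocess.CalledProcessError on failure."""
    subprocess.run(build_convert_command(saved_model_dir, output_path, opset),
                   check=True)


if __name__ == "__main__":
    # Print the command instead of running it, so this sketch has no side effects.
    print(" ".join(build_convert_command("my_keras_model", "model.onnx", opset=13)))
```

Calling `convert_saved_model("my_keras_model", "model.onnx")` would then write the ONNX file directly, and `onnx.load("model.onnx")` can load it for validation with `check_model`.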
