skl2onnx: Convert Scikit-learn Models to ONNX

1.20.0 · active · verified Sat Apr 11

skl2onnx is a Python library that converts scikit-learn models and pipelines to the ONNX (Open Neural Network Exchange) format. A converted model is portable across ONNX-compatible runtimes, and inference is often faster, particularly with ONNX Runtime. The library is actively maintained, with releases typically on a monthly or bi-monthly cadence; the current version is 1.20.0.

Warnings

Install
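
The quickstart below imports skl2onnx, scikit-learn, onnxruntime, and numpy, all of which are published on PyPI under those distribution names and can be installed with pip. As a minimal sketch, using only the standard library, the following checks which of them are present in the current environment. The helper name `installed_versions` and the `REQUIRED` tuple are illustrative and not part of skl2onnx.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as published on PyPI (note: scikit-learn, not sklearn)
REQUIRED = ("skl2onnx", "scikit-learn", "onnxruntime", "numpy")


def installed_versions(packages=REQUIRED):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

Any package reported as not installed would need to be added before running the quickstart.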

Imports

Quickstart

This quickstart trains a scikit-learn RandomForestClassifier on the Iris dataset, converts it to ONNX with `skl2onnx.to_onnx`, saves the ONNX model to a file, and loads it with ONNX Runtime for inference. It covers the typical workflow from training to ONNX-based prediction.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx
import onnxruntime as rt

# 1. Train a scikit-learn model
iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32) # ONNX typically uses float32
model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(X, y)

# 2. Convert the scikit-learn model to ONNX format
# `X[:1]` is used to infer the input types and shapes
onx_model = to_onnx(model, X[:1])

# 3. Save the ONNX model to a file
with open("rf_iris.onnx", "wb") as f:
    f.write(onx_model.SerializeToString())

# 4. Load and make predictions with ONNX Runtime
sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
output_names = [output.name for output in sess.get_outputs()]

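# Predict on the first sample (X is already float32, so no cast is needed)
predictions = sess.run(output_names, {input_name: X[0:1]})

# For a classifier, the first output holds the labels and the second the probabilities
# (by default a list of {class: probability} dicts, one per sample)
print(f"Original model prediction: {model.predict(X[0:1])}")
print(f"ONNX Runtime prediction (label): {predictions[0]}")
print(f"ONNX Runtime prediction (probabilities): {predictions[1]}")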
