Torch-ORT

Version 1.19.2 | verified Fri May 01 | auth: no | language: Python

torch-ort accelerates PyTorch models using ONNX Runtime. Version 1.19.2 supports PyTorch 2.x and ONNX Runtime 1.19.x. torch-ort releases are versioned in step with ONNX Runtime releases.

pip install torch-ort
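error ModuleNotFoundError: No module named 'onnxruntime'
cause ONNX Runtime is not installed. ORTModule comes from the training build of ONNX Runtime, so the inference-only onnxruntime package is not sufficient.
fix pip install onnxruntime-training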
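A quick way to see which piece is missing is to probe the modules without importing torch-ort itself. The sketch below uses only the standard library; the module names it checks are assumptions based on ORTModule living under onnxruntime.training in the training build of ONNX Runtime, so adjust them to match your installation. probe_modules is a hypothetical helper name.

```python
import importlib.util


def probe_modules(names):
    """Report which modules (top-level or dotted) can be found, without running them.

    find_spec on a dotted name imports the parent package first, so a missing
    or broken parent raises ImportError; that case is reported as False.
    """
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ImportError:  # includes ModuleNotFoundError for a missing parent
            status[name] = False
    return status


if __name__ == "__main__":
    # Assumed module layout; ORTModule is expected under onnxruntime.training.
    report = probe_modules(["onnxruntime", "onnxruntime.training", "torch_ort"])
    for module, found in report.items():
        print(f"{module:25} {'found' if found else 'MISSING'}")
```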
error RuntimeError: Exporting the operator ... to ONNX is not supported.
cause The model uses a PyTorch operator that ONNX export does not yet support.
fix Replace or reimplement the unsupported operator with supported ops. See the list of supported operators at https://pytorch.org/docs/stable/onnx.html
error AttributeError: module 'torch_ort' has no attribute 'ORTModule'
cause A very old (pre-1.0) version of torch-ort is installed, in which ORTModule had a different name.
fix Upgrade torch-ort to the latest version: pip install --upgrade torch-ort
breaking torch-ort v1.14+ requires ONNX Runtime 1.14+. Older ONNX Runtime installations cause import errors.
fix Upgrade onnxruntime to 1.14+ or pin torch-ort to <1.14.0.
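To catch this mismatch before it shows up as an import error, the installed versions can be compared up front. This is a minimal sketch under stated assumptions: it reads versions with importlib.metadata, treats onnxruntime-training, onnxruntime and onnxruntime-gpu as the possible ONNX Runtime distribution names (check which one you actually installed), and encodes only the 1.14 rule above plus the pre-1.0 ORTModule naming issue. parse_version, check_compat and installed_versions are hypothetical helpers, and parse_version is deliberately simpler than a full PEP 440 parser.

```python
import re
from importlib import metadata


def parse_version(text):
    """Parse leading numeric components: '1.14.1' -> (1, 14, 1), '1.15.0rc1' -> (1, 15, 0).

    Simplified on purpose: parsing stops at the first component that is not
    purely numeric, after keeping that component's leading digits.
    """
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # e.g. '0rc1' or '0+cu118': keep the 0, ignore the suffix
    return tuple(parts)


def check_compat(torch_ort_version, ort_version):
    """Return a list of known problems for this version pair (empty if none)."""
    problems = []
    tort = parse_version(torch_ort_version)
    ort = parse_version(ort_version)
    if tort < (1,):
        problems.append("torch-ort is pre-1.0, where ORTModule had a different name; upgrade torch-ort.")
    if tort >= (1, 14) and ort < (1, 14):
        problems.append("torch-ort >= 1.14 needs ONNX Runtime >= 1.14; upgrade onnxruntime or pin torch-ort<1.14.0.")
    return problems


def installed_versions(dist_names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in dist_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; only the first ONNX Runtime build found is checked.
    found = installed_versions(["torch-ort", "onnxruntime-training", "onnxruntime", "onnxruntime-gpu"])
    tort = found.pop("torch-ort")
    ort = next((v for v in found.values() if v is not None), None)
    if tort is None or ort is None:
        print("Missing package(s):", {"torch-ort": tort, **found})
    else:
        for message in check_compat(tort, ort) or ["No known incompatibility."]:
            print(message)
```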
deprecated The 'enable_training' function has been deprecated since v1.15.0. Training mode is now detected automatically.
fix Remove calls to enable_training; ORTModule now auto-detects training vs inference.
gotcha ORTModule does not support all PyTorch operations. Custom autograd.Function subclasses may fail silently.
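fix Test the model under ORTModule and fall back to native PyTorch if unsupported ops are encountered. To trace export problems, pass debug options when wrapping, e.g. ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix="model", log_level=LogLevel.VERBOSE)), with DebugOptions and LogLevel imported from torch_ort.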
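One way to apply the fallback advice is to wrap the model, run a single smoke-test forward pass, and keep the original model if anything goes wrong. The sketch below is framework-agnostic: wrap_with_fallback and outputs_match are hypothetical names, and wrapper_factory stands in for ORTModule. Because some failures are silent rather than raised, the optional outputs_match check compares the wrapped output with the original one; use deterministic inputs (eval mode, no dropout) when comparing.

```python
import logging

logger = logging.getLogger(__name__)


def wrap_with_fallback(model, wrapper_factory, sample_input, outputs_match=None):
    """Wrap `model`, smoke-test it on `sample_input`, and return the original on failure.

    wrapper_factory: callable taking the model and returning a callable
        (with torch-ort this would be ORTModule).
    outputs_match: optional callable(original_out, wrapped_out) -> bool used to
        catch wrappers that run without raising but produce different results.
    """
    try:
        wrapped = wrapper_factory(model)
        # The first call is where lazy wrappers typically export/build their graph,
        # so export errors tend to surface here rather than at construction time.
        wrapped_out = wrapped(sample_input)
    except Exception as exc:  # broad on purpose: export errors vary by operator
        logger.warning("Falling back to the unwrapped model: %s", exc)
        return model

    if outputs_match is not None and not outputs_match(model(sample_input), wrapped_out):
        logger.warning("Wrapped model output differs from the original; falling back.")
        return model
    return wrapped
```

With PyTorch and torch-ort installed, a call might look like wrap_with_fallback(model, ORTModule, x, lambda a, b: torch.allclose(a, b, atol=1e-5)); the tolerance is a guess and should be tuned to the model.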
gotcha The model does not need to be in eval mode before wrapping: ORTModule works in both train and eval mode. However, state dict keys may change if you convert the model mid-training.
fix Convert the model before training starts or after the final evaluation to avoid key mismatches, and call load_state_dict on the original model, not on the ORTModule-wrapped one.
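If a checkpoint was saved from the wrapped model and its keys carry an extra prefix, the keys can be remapped before calling load_state_dict on the original model. Whether a prefix appears, and what it is, depends on your setup, so print the checkpoint's keys first; the "wrapper." prefix used below is a hypothetical placeholder, and strip_key_prefix is an illustrative helper, not part of torch-ort. It works on any mapping, so a state dict of tensors is handled the same way as the plain dict in this example.

```python
def strip_key_prefix(state_dict, prefix):
    """Return a new dict with `prefix` removed from every key that starts with it.

    Keys without the prefix are kept unchanged. Raises KeyError if stripping
    would make two entries collide, rather than silently overwriting one.
    """
    remapped = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in remapped:
            raise KeyError(f"duplicate key after stripping prefix: {new_key!r}")
        remapped[new_key] = value
    return remapped


if __name__ == "__main__":
    # "wrapper." is a hypothetical prefix; inspect your own checkpoint's keys first.
    checkpoint = {"wrapper.weight": [0.1, 0.2], "wrapper.bias": [0.0]}
    print(strip_key_prefix(checkpoint, "wrapper."))
```

Afterwards, model.load_state_dict(remapped) on the original module still raises on missing or unexpected keys under its default strict=True, which is a useful check that the remapping was complete.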

Wrap a PyTorch model with ORTModule to accelerate inference/training via ONNX Runtime.

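import torch
from torch_ort import ORTModule

# Any torch.nn.Module can be wrapped; a single linear layer keeps the example small.
model = torch.nn.Linear(10, 5)
ort_model = ORTModule(model)  # wraps the model; ONNX export is deferred to the first forward call

x = torch.randn(3, 10)  # batch of 3 inputs with 10 features each
output = ort_model(x)  # executed by ONNX Runtime; output shape is (3, 5)
print(output)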
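To confirm that wrapping actually speeds up a given model, time the original and wrapped modules on the same input. The sketch below is a generic, standard-library timing helper (median_latency is a hypothetical name); the warm-up calls are there because the first forward pass through a wrapped model may include one-time export and setup cost. On a GPU, kernels launch asynchronously, so a fair measurement would also call torch.cuda.synchronize() inside the timed function.

```python
import statistics
import time


def median_latency(fn, arg, warmup=3, repeats=20):
    """Median wall-clock seconds per call of fn(arg), measured after `warmup` untimed calls.

    The median is less sensitive than the mean to occasional slow calls
    (garbage collection, OS scheduling).
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn(arg)  # untimed: absorbs one-time setup such as lazy export
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

For example, compare median_latency(model, x) with median_latency(ort_model, x), ideally inside torch.no_grad() when measuring inference.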