Orbax Export
Orbax Export is a Python library for JAX users that provides utilities to serialize and export JAX models into the TensorFlow SavedModel format. It is a specialized component of the broader Orbax ecosystem, which offers common training utilities for JAX. The library is actively maintained and released on a regular cadence; the latest version is 0.0.8, released in September 2025.
Warnings
- gotcha Orbax Export requires TensorFlow for its core functionality (exporting to SavedModel), but TensorFlow is *not* installed by default. Users must explicitly install it using `pip install tensorflow` or by installing `orbax-export` with the `[all]` extra: `pip install orbax-export[all]`.
- breaking The original `orbax` PyPI package (frozen at 0.1.6/0.1.9) is no longer the primary installation target. Users should directly install specific sub-packages like `orbax-export` or `orbax-checkpoint` to avoid dependency bloat. While existing `from orbax import export` statements may still work due to namespace preservation, installing the specific sub-package is the recommended approach.
- gotcha There can be version incompatibilities between `orbax-export` and `orbax-checkpoint` (e.g., `orbax-export 0.0.5` was incompatible with `orbax-checkpoint 0.9.0`). When using both, always ensure they are compatible by checking release notes or testing your setup.
- deprecated Many internal Orbax implementations were refactored into a private `_src` directory. While most public APIs should remain unaffected, some lightly-used public APIs might have become private. This might lead to `ImportError` or `AttributeError` for users relying on such paths.
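The installation and version warnings above can be checked before attempting an export. The following is a minimal sketch that uses only the Python standard library; `report_versions` is a hypothetical helper name, and the default package list simply mirrors the distributions named in the warnings. It reports what is installed and does not decide whether two versions are compatible.

```python
from importlib import metadata


def report_versions(packages=("orbax-export", "orbax-checkpoint", "tensorflow")):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed under this exact name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

A `None` entry for `tensorflow` suggests the SavedModel export will fail at import time, although TensorFlow installed under a different distribution name (an assumption worth checking in your environment) would also show up as missing here.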
Install
- pip install orbax-export
- pip install orbax-export[all]
Imports
- ExportManager
from orbax.export import ExportManager
- JaxModule
from orbax.export import JaxModule
- ServingConfig
from orbax.export import ServingConfig
- orbax.export
import orbax.export
Quickstart
import os
import jax
import jax.numpy as jnp
import tensorflow as tf # Required for SavedModel export
from orbax.export import ExportManager, JaxModule, ServingConfig
# Dummy JAX model and parameters for demonstration
class SimpleJAXModel:
    def apply(self, params, inputs):
        return params['w'] * inputs + params['b']
model_instance = SimpleJAXModel()
final_model_params_to_save = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
# JAX Apply Function: The core JAX logic for the model's forward pass.
@jax.jit
def jax_model_apply_fn_for_export(params, inputs):
    return model_instance.apply(params, inputs)
# Optional: TF Pre-processing Function
def tf_preprocess_fn_for_export(input_tensor: tf.Tensor) -> tf.Tensor:
    return tf.cast(input_tensor, tf.float32) / 255.0
# Optional: TF Post-processing Function
def tf_postprocess_fn_for_export(output_tensor: tf.Tensor) -> dict[str, tf.Tensor]:
    return {'output': output_tensor}
# Create a JaxModule from the parameters and the apply function.
# input_polymorphic_shape='b, ...' marks the leading (batch) dimension as
# dynamic, matching the None in the input signature below.
jax_module = JaxModule(
    final_model_params_to_save,
    jax_model_apply_fn_for_export,
    input_polymorphic_shape='b, ...',
)
# Define serving configs. The TF pre- and post-processing functions are
# attached here, not on the JaxModule.
serving_configs = [
    ServingConfig(
        'serving_default',
        input_signature=[
            tf.TensorSpec(shape=[None, 1], dtype=tf.int32, name='input')
        ],
        tf_preprocessor=tf_preprocess_fn_for_export,
        tf_postprocessor=tf_postprocess_fn_for_export,
    )
]
# Define export path
export_path = os.environ.get('ORBAX_EXPORT_PATH', '/tmp/my_jax_model_export')
# Export the model
export_manager = ExportManager(jax_module, serving_configs)
export_manager.save(export_path)
print(f"JAX model exported to TensorFlow SavedModel at: {export_path}")
# Basic verification (optional)
loaded_model = tf.saved_model.load(export_path)
input_data = tf.constant([[5]], dtype=tf.int32)
output = loaded_model.signatures['serving_default'](input_data)
print(f"Loaded model output for input {input_data.numpy()}: {output['output'].numpy()}")
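As a sanity check on the value printed by the verification step, the same arithmetic can be reproduced by hand. The sketch below is plain Python and assumes the quickstart's parameters (`w = 2.0`, `b = 1.0`) and its `/ 255.0` pre-processing; `expected_output` is a hypothetical helper and not part of Orbax.

```python
def expected_output(raw_value, w=2.0, b=1.0):
    """Mirror the quickstart pipeline: scale the raw input, then apply w * x + b."""
    scaled = float(raw_value) / 255.0  # same scaling as the TF pre-processing step
    return w * scaled + b


# Reference value to compare against the loaded model's output for the input 5.
print(expected_output(5))
```

The loaded SavedModel computes in float32, so a comparison against this reference should use a small tolerance, not exact equality.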