Orbax Export

0.0.8 · active · verified Sun Apr 12

Orbax Export is a Python library for JAX users that provides utilities to serialize and export JAX models to the TensorFlow SavedModel format. It is a specialized component of the broader Orbax ecosystem, which offers common training utilities for JAX. The library is actively maintained: its latest version is 0.0.8, released in September 2025, and it follows a regular release cadence.

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to define a simple JAX model, wrap it with `JaxModule`, configure serving signatures with `ServingConfig`, and then export it to the TensorFlow SavedModel format using `ExportManager.save()`. It includes optional TensorFlow pre- and post-processing functions that will be integrated into the SavedModel graph. Ensure `tensorflow` is installed to run this example.

import os
import jax
import jax.numpy as jnp
import tensorflow as tf # Required for SavedModel export
from orbax.export import ExportManager, JaxModule, ServingConfig

# Dummy JAX model and parameters for demonstration
class SimpleJAXModel:
    """Toy linear model computing ``w * inputs + b`` for the export demo."""

    def apply(self, params, inputs):
        """Run the forward pass.

        Args:
          params: mapping providing a ``'w'`` (scale) and ``'b'`` (offset) entry.
          inputs: value(s) to transform; broadcasting follows the operands'
            own ``*`` and ``+`` semantics.

        Returns:
          ``params['w'] * inputs + params['b']``.
        """
        weight, bias = params['w'], params['b']
        return weight * inputs + bias

# Single shared model instance (used by the jitted apply function) and the
# scalar weight/bias parameters that will be exported alongside it.
model_instance = SimpleJAXModel()
final_model_params_to_save = dict(w=jnp.array(2.0), b=jnp.array(1.0))

# JAX Apply Function: The core JAX logic for the model's forward pass.
@jax.jit
def jax_model_apply_fn_for_export(params, inputs):
    """Jitted forward pass delegating to the module-level ``model_instance``.

    This ``(params, inputs)`` callable is the function wrapped for export.
    """
    forward = model_instance.apply
    return forward(params, inputs)

# Optional: TF Pre-processing Function
def tf_preprocess_fn_for_export(input_tensor: tf.Tensor) -> tf.Tensor:
    """Cast the raw input to float32 and divide it by 255."""
    as_float = tf.cast(input_tensor, tf.float32)
    return as_float / 255.0

# Optional: TF Post-processing Function
def tf_postprocess_fn_for_export(output_tensor: tf.Tensor) -> dict[str, tf.Tensor]:
    """Wrap the model output in a dict under the ``'output'`` key."""
    return dict(output=output_tensor)

# Wrap the parameters and the JAX apply function in a JaxModule. JaxModule's
# leading positional arguments are (params, apply_fn); it has no
# pre-/post-processing arguments -- those are attached per signature through
# ServingConfig below.
# `input_polymorphic_shape='b, ...'` marks the leading dimension of the model
# input as a symbolic batch size, so the JAX->TF conversion accepts the `None`
# batch dimension declared in the TensorSpec below.
jax_module = JaxModule(
    final_model_params_to_save,
    jax_model_apply_fn_for_export,
    input_polymorphic_shape='b, ...',
)

# Define serving signatures. ExportManager expects a sequence of ServingConfig
# objects, each carrying its own signature key. The TF pre-/post-processors
# are attached here so they run around the converted JAX function inside the
# exported graph. ServingConfig has no output-signature field: the signature's
# outputs come from the postprocessor (here a dict with one 'output' tensor).
serving_signatures = [
    ServingConfig(
        signature_key='serving_default',
        input_signature=[
            tf.TensorSpec(shape=[None, 1], dtype=tf.int32, name='input')
        ],
        tf_preprocessor=tf_preprocess_fn_for_export,
        tf_postprocessor=tf_postprocess_fn_for_export,
    ),
]

# Define export path (overridable via the ORBAX_EXPORT_PATH env variable).
export_path = os.environ.get('ORBAX_EXPORT_PATH', '/tmp/my_jax_model_export')

# Export the model as a TensorFlow SavedModel.
export_manager = ExportManager(jax_module, serving_signatures)
export_manager.save(export_path)

print(f"JAX model exported to TensorFlow SavedModel at: {export_path}")

# Basic verification (optional): reload the SavedModel and run its signature.
# Signature functions restored by tf.saved_model.load are keyword-only
# (a positional call raises "signature_wrapper(*, ...) takes 0 positional
# arguments"), so read the input name from the signature rather than
# passing the tensor positionally or hard-coding the name.
loaded_model = tf.saved_model.load(export_path)
serving_fn = loaded_model.signatures['serving_default']
_, input_specs = serving_fn.structured_input_signature
(input_name,) = input_specs  # this signature declares exactly one input
input_data = tf.constant([[5]], dtype=tf.int32)
output = serving_fn(**{input_name: input_data})
print(f"Loaded model output for input {input_data.numpy()}: {output['output'].numpy()}")

view raw JSON →