{"id":4673,"library":"orbax-export","title":"Orbax Export","description":"Orbax Export is a Python library for serializing and exporting JAX models to the TensorFlow SavedModel format. It is a specialized component of the broader Orbax ecosystem, which offers common training utilities for JAX. The library is actively maintained; its latest version is 0.0.8, released in September 2025.","status":"active","version":"0.0.8","language":"en","source_language":"en","source_url":"https://github.com/google/orbax","tags":["JAX","machine learning","serialization","export","TensorFlow SavedModel","AI/ML"],"install":[{"cmd":"pip install orbax-export","lang":"bash","label":"Basic Installation"},{"cmd":"pip install orbax-export[all]","lang":"bash","label":"Installation with TensorFlow (Recommended for SavedModel export)"}],"dependencies":[{"reason":"Essential for its primary functionality: exporting JAX models to the TensorFlow SavedModel format. 
Not installed by default.","package":"tensorflow","optional":true},{"reason":"Core library for defining JAX models that are to be exported.","package":"jax","optional":false},{"reason":"Often used in conjunction for checkpointing JAX models before export; compatibility between versions is crucial.","package":"orbax-checkpoint","optional":true}],"imports":[{"symbol":"ExportManager","correct":"from orbax.export import ExportManager"},{"symbol":"JaxModule","correct":"from orbax.export import JaxModule"},{"symbol":"ServingConfig","correct":"from orbax.export import ServingConfig"},{"symbol":"orbax.export","correct":"import orbax.export"}],"quickstart":{"code":"import os\nimport jax\nimport jax.numpy as jnp\nimport tensorflow as tf  # Required for SavedModel export\nfrom orbax.export import ExportManager, JaxModule, ServingConfig\n\n# Dummy JAX model and parameters for demonstration\nclass SimpleJAXModel:\n    def apply(self, params, inputs):\n        return params['w'] * inputs + params['b']\n\nmodel_instance = SimpleJAXModel()\nfinal_model_params_to_save = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n\n# JAX apply function: the core JAX logic for the model's forward pass.\n@jax.jit\ndef jax_model_apply_fn_for_export(params, inputs):\n    return model_instance.apply(params, inputs)\n\n# Optional: TF pre-processing function\ndef tf_preprocess_fn_for_export(input_tensor: tf.Tensor) -> tf.Tensor:\n    return tf.cast(input_tensor, tf.float32) / 255.0\n\n# Optional: TF post-processing function\ndef tf_postprocess_fn_for_export(output_tensor: tf.Tensor) -> dict[str, tf.Tensor]:\n    return {'output': output_tensor}\n\n# Create a JaxModule. The polymorphic shape marks the leading (batch)\n# dimension as symbolic so the None batch size below can be traced.\njax_module = JaxModule(\n    params=final_model_params_to_save,\n    apply_fn=jax_model_apply_fn_for_export,\n    input_polymorphic_shape='b, ...'\n)\n\n# Define serving configs. TF pre- and post-processing functions are\n# attached here, not on the JaxModule.\nserving_configs = [\n    ServingConfig(\n        'serving_default',\n        input_signature=[\n            tf.TensorSpec(shape=[None, 1], dtype=tf.int32, name='input')\n        ],\n        tf_preprocessor=tf_preprocess_fn_for_export,\n        tf_postprocessor=tf_postprocess_fn_for_export\n    )\n]\n\n# Define export path\nexport_path = os.environ.get('ORBAX_EXPORT_PATH', '/tmp/my_jax_model_export')\n\n# Export the model\nexport_manager = ExportManager(jax_module, serving_configs)\nexport_manager.save(export_path)\n\nprint(f\"JAX model exported to TensorFlow SavedModel at: {export_path}\")\n\n# Basic verification (optional)\nloaded_model = tf.saved_model.load(export_path)\ninput_data = tf.constant([[5]], dtype=tf.int32)\noutput = loaded_model.signatures['serving_default'](input_data)\nprint(f\"Loaded model output for input {input_data.numpy()}: {output['output'].numpy()}\")\n","lang":"python","description":"This quickstart demonstrates how to define a simple JAX model, wrap it with `JaxModule`, configure a serving signature with `ServingConfig`, and export it to the TensorFlow SavedModel format using `ExportManager.save()`. The optional TensorFlow pre- and post-processing functions are attached through `ServingConfig` (`tf_preprocessor` and `tf_postprocessor`) and become part of the SavedModel graph. Ensure `tensorflow` is installed to run this example."},"warnings":[{"fix":"Install TensorFlow manually: `pip install tensorflow` or use the `[all]` extra: `pip install orbax-export[all]`.","message":"Orbax Export requires TensorFlow for its core functionality (exporting to SavedModel), but TensorFlow is *not* installed by default and must be installed explicitly.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Ensure you install `orbax-export` directly via `pip install orbax-export` instead of relying on the legacy `orbax` package.","message":"The original `orbax` PyPI package (frozen at 0.1.6/0.1.9) is no longer the primary installation target. 
Users should directly install specific sub-packages like `orbax-export` or `orbax-checkpoint` to avoid dependency bloat. While existing `from orbax import export` statements may still work due to namespace preservation, installing the specific sub-package is the recommended approach.","severity":"breaking","affected_versions":"Users migrating from `orbax<0.1.9` to `orbax-export`"},{"fix":"Consult the Orbax GitHub repository or documentation for recommended compatible versions of `orbax-export` and `orbax-checkpoint`. Update both packages to their latest compatible releases.","message":"There can be version incompatibilities between `orbax-export` and `orbax-checkpoint` (e.g., `orbax-export 0.0.5` was incompatible with `orbax-checkpoint 0.9.0`). When using both, always ensure they are compatible by checking release notes or testing your setup.","severity":"gotcha","affected_versions":"All versions when used with `orbax-checkpoint`"},{"fix":"If encountering issues, refer to the official Orbax documentation and API reference for the correct, currently public API paths. Contact the Orbax team if a needed API has become private.","message":"Many internal Orbax implementations were refactored into a private `_src` directory. While most public APIs should remain unaffected, some lightly-used public APIs might have become private. This might lead to `ImportError` or `AttributeError` for users relying on such paths.","severity":"deprecated","affected_versions":"Versions released after 2024-10-01"}],"env_vars":null,"last_verified":"2026-04-12T00:00:00.000Z","next_check":"2026-07-11T00:00:00.000Z"}