{"id":1613,"library":"orbax-checkpoint","title":"Orbax Checkpoint","description":"Orbax Checkpoint is an asynchronous, fault-tolerant checkpointing library designed primarily for JAX and Flax models. It saves and restores large-scale machine learning model state and arbitrary data structures efficiently, with support for distributed environments and custom serialization. The current version is 0.11.34; releases are frequent and track JAX/Flax developments.","status":"active","version":"0.11.34","language":"en","source_language":"en","source_url":"https://github.com/google/orbax","tags":["jax","flax","checkpointing","machine-learning","serialization"],"install":[{"cmd":"pip install orbax-checkpoint jax flax","lang":"bash","label":"Typical installation for JAX/Flax users"},{"cmd":"pip install orbax-checkpoint","lang":"bash","label":"Minimal installation"}],"dependencies":[{"reason":"Underlying storage and data serialization, especially for distributed checkpoints.","package":"tensorstore","optional":false},{"reason":"Orbax is built for the JAX ecosystem and operates on JAX arrays and PyTrees; it is installed as a dependency of orbax-checkpoint.","package":"jax","optional":false},{"reason":"Commonly used with JAX and Orbax for neural network checkpointing.","package":"flax","optional":true}],"imports":[{"symbol":"orbax.checkpoint","correct":"import orbax.checkpoint as ocp"},{"symbol":"CheckpointManager","correct":"from orbax.checkpoint import CheckpointManager"},{"symbol":"CheckpointManagerOptions","correct":"from orbax.checkpoint import CheckpointManagerOptions"},{"note":"StandardSave/StandardRestore are exposed through the orbax.checkpoint.args module (commonly written ocp.args.StandardSave), which is the path used in the Orbax documentation.","wrong":"from orbax.checkpoint import StandardSave","symbol":"StandardSave","correct":"from orbax.checkpoint.args import StandardSave"},{"note":"StandardSave/StandardRestore are exposed through the orbax.checkpoint.args module (commonly written ocp.args.StandardRestore), which is the path used in the Orbax documentation.","wrong":"from orbax.checkpoint import StandardRestore","symbol":"StandardRestore","correct":"from orbax.checkpoint.args import StandardRestore"}],"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nimport orbax.checkpoint as ocp\nimport os\nimport shutil\n\n# Define a temporary checkpoint directory\nckpt_dir = '/tmp/my_orbax_checkpoint_example'\nif os.path.exists(ckpt_dir):\n    shutil.rmtree(ckpt_dir)\nos.makedirs(ckpt_dir, exist_ok=True)\n\n# 1. Create a CheckpointManager\noptions = ocp.CheckpointManagerOptions(\n    save_interval_steps=1,\n    max_to_keep=3,\n)\nmngr = ocp.CheckpointManager(ckpt_dir, options=options)\n\n# 2. Prepare some data to save\nstep = 0\ndata_to_save = {'params': jnp.array([1.0, 2.0, 3.0])}\nprint(f\"Saving data at step {step}: {data_to_save['params']}\")\n\n# 3. Save the checkpoint\n# Wrap the data in StandardSave so the manager knows how to serialize it\nmngr.save(step, args=ocp.args.StandardSave(data_to_save))\nmngr.wait_until_finished()  # Block until the asynchronous save completes\n\n# Simulate another step and save\nstep = 1\ndata_to_save = {'params': jnp.array([4.0, 5.0, 6.0])}\nprint(f\"Saving data at step {step}: {data_to_save['params']}\")\nmngr.save(step, args=ocp.args.StandardSave(data_to_save))\nmngr.wait_until_finished()\n\n# 4. Restore the latest checkpoint\nlatest_step = mngr.latest_step()\nif latest_step is not None:\n    print(f\"\\nRestoring data from latest step: {latest_step}\")\n    # Pass a template with the expected structure to StandardRestore\n    restored_data = mngr.restore(latest_step, args=ocp.args.StandardRestore(data_to_save))\n    print(f\"Restored data: {restored_data['params']}\")\nelse:\n    print(\"No checkpoint found to restore.\")\n\n# 5. Close the manager\nmngr.close()\n\n# Clean up\nif os.path.exists(ckpt_dir):\n    shutil.rmtree(ckpt_dir)\n","lang":"python","description":"Demonstrates how to initialize a CheckpointManager, save JAX array data, and restore the latest checkpoint. Highlights the use of `ocp.args.StandardSave` and `ocp.args.StandardRestore` as explicit serialization arguments and the importance of `wait_until_finished()`."},"warnings":[{"fix":"Consult the latest Orbax documentation and examples for `orbax.checkpoint.args` usage. Data passed to `save()` and `restore()` typically needs to be wrapped, e.g., `args=ocp.args.StandardSave(data)`.","message":"The save/restore API has changed across minor versions: data is now passed through explicit argument wrappers such as `ocp.args.StandardSave` and `ocp.args.StandardRestore`.","severity":"breaking","affected_versions":"0.10.x to 0.11.x (and potentially earlier major internal refactors)"},{"fix":"Always call `CheckpointManager.wait_until_finished()` after `save()` calls before relying on the checkpoint or exiting the program. Using the manager as a context manager (`with ocp.CheckpointManager(...) as mngr:`) closes it on exit, which waits for pending saves.","message":"CheckpointManager saves are asynchronous. Failing to call `.wait_until_finished()` can leave a checkpoint incomplete or corrupted if the program exits prematurely.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Let Orbax manage checkpoint paths and directory structure. Avoid manual file operations inside `ckpt_dir`.","message":"Manually modifying checkpoint directories or subfolders can interfere with CheckpointManager's internal state and cleanup logic (`max_to_keep`).","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-09T00:00:00.000Z","next_check":"2026-07-08T00:00:00.000Z"}