Orbax Checkpoint
Orbax Checkpoint is a robust, asynchronous, fault-tolerant checkpointing library designed primarily for JAX and Flax models. It saves and restores large-scale model state and arbitrary data structures efficiently, with support for distributed environments and custom serialization. The current version is 0.11.34, and releases are frequent to keep pace with JAX and Flax development.
Warnings
- breaking: The serialization API has changed across minor versions, notably around `ocp.args.StandardSave` and `ocp.args.StandardRestore`. Data passed to `save()` and `restore()` must be wrapped explicitly in these argument objects.
- gotcha: CheckpointManager saves run asynchronously. If the program exits before `.wait_until_finished()` is called, the checkpoint may be incomplete or corrupted.
- gotcha: Manually modifying checkpoint directories or their subfolders can interfere with CheckpointManager's internal state and its cleanup logic (`max_to_keep`).
Install
- pip install orbax-checkpoint jax flax
- pip install orbax-checkpoint
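Because the API has shifted between minor versions, it can help to confirm which release is actually installed. The snippet below is a small sketch using only the standard library; the variable name `installed_version` is illustrative.

```python
from importlib.metadata import PackageNotFoundError, version

try:
    # Look up the distribution name used on PyPI.
    installed_version = version("orbax-checkpoint")
except PackageNotFoundError:
    # The package is not installed in this environment.
    installed_version = None

print(f"orbax-checkpoint version: {installed_version}")
```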
Imports
- orbax.checkpoint
import orbax.checkpoint as ocp
- CheckpointManager
from orbax.checkpoint import CheckpointManager
- CheckpointManagerOptions
from orbax.checkpoint import CheckpointManagerOptions
- StandardSave
from orbax.checkpoint.args import StandardSave
- StandardRestore
from orbax.checkpoint.args import StandardRestore
Quickstart
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
import os
import shutil
# Define a temporary checkpoint directory
ckpt_dir = '/tmp/my_orbax_checkpoint_example'
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)
os.makedirs(ckpt_dir, exist_ok=True)
# 1. Create a CheckpointManager
# keep_time_interval is left at its default (None)
options = ocp.CheckpointManagerOptions(
    save_interval_steps=1,
    max_to_keep=3,
)
mngr = ocp.CheckpointManager(ckpt_dir, options=options)
# 2. Prepare some data to save
step = 0
data_to_save = {'params': jnp.array([1.0, 2.0, 3.0])}
print(f"Saving data at step {step}: {data_to_save['params']}")
# 3. Save the checkpoint
# Wrap the data in StandardSave so the serialization format is explicit
mngr.save(step, args=ocp.args.StandardSave(data_to_save))
mngr.wait_until_finished()  # Block until the asynchronous save completes
# Simulate more steps and saves
step = 1
data_to_save = {'params': jnp.array([4.0, 5.0, 6.0])}
print(f"Saving data at step {step}: {data_to_save['params']}")
mngr.save(step, args=ocp.args.StandardSave(data_to_save))
mngr.wait_until_finished()
# 4. Restore the latest checkpoint
latest_step = mngr.latest_step()
if latest_step is not None:
    print(f"\nRestoring data from latest step: {latest_step}")
    # StandardRestore takes a template tree describing the expected structure;
    # here the most recently saved tree is reused as that template
    restored_data = mngr.restore(latest_step, args=ocp.args.StandardRestore(data_to_save))
    print(f"Restored data: {restored_data['params']}")
else:
    print("No checkpoint found to restore.")
# 5. Close the manager
mngr.close()
# Clean up
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)