Orbax Checkpoint

0.11.34 · active · verified Thu Apr 09

Orbax Checkpoint provides a robust, asynchronous, and fault-tolerant checkpointing library primarily designed for JAX and Flax models. It efficiently saves and restores large-scale machine learning model states and arbitrary data structures, and supports distributed environments and custom serialization. It is currently at version 0.11.34 and is updated frequently to keep pace with JAX/Flax development.

Warnings

Install

Imports

Quickstart

Demonstrates how to create a CheckpointManager, save JAX arrays, and restore the latest checkpoint. The example uses `ocp.StandardSave` and `ocp.StandardRestore` to state explicitly how each item is serialized. It also calls `wait_until_finished()` after each save: saves can run asynchronously, and this call blocks until the checkpoint has been fully written.

import os
import shutil
import tempfile

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Create a fresh, uniquely named temporary directory for the checkpoints.
# Using mkdtemp (rather than a fixed path that is deleted up front) avoids
# destroying any pre-existing data and works on any platform. The returned
# path is absolute.
ckpt_dir = tempfile.mkdtemp(prefix='orbax_checkpoint_example_')

# 1. Create a CheckpointManager
# save_interval_steps=1 allows a save at every step; max_to_keep=3 retains
# only the three most recent checkpoints. Time-based retention is configured
# via `keep_time_interval` (a datetime.timedelta) and is off by default, so it
# is simply omitted here.
options = ocp.CheckpointManagerOptions(
    save_interval_steps=1,
    max_to_keep=3,
)
mngr = ocp.CheckpointManager(ckpt_dir, options=options)

try:
    # 2. Prepare some data to save
    step = 0
    data_to_save = {'params': jnp.array([1.0, 2.0, 3.0])}
    print(f"Saving data at step {step}: {data_to_save['params']}")

    # 3. Save the checkpoint
    # Wrap the data with StandardSave to state explicitly how it is serialized.
    mngr.save(step, args=ocp.StandardSave(data_to_save))
    # Saves may run asynchronously; block until the checkpoint is fully
    # written before relying on it.
    mngr.wait_until_finished()

    # Simulate more steps and saves
    step = 1
    data_to_save = {'params': jnp.array([4.0, 5.0, 6.0])}
    print(f"Saving data at step {step}: {data_to_save['params']}")
    mngr.save(step, args=ocp.StandardSave(data_to_save))
    mngr.wait_until_finished()

    # 4. Restore the latest checkpoint
    latest_step = mngr.latest_step()
    if latest_step is not None:
        print(f"\nRestoring data from latest step: {latest_step}")
        # The template given to StandardRestore describes the expected
        # structure of the restored data.
        restored_data = mngr.restore(
            latest_step, args=ocp.StandardRestore(data_to_save)
        )
        print(f"Restored data: {restored_data['params']}")
    else:
        print("No checkpoint found to restore.")
finally:
    # 5. Close the manager and remove the temporary directory, even if an
    # earlier step raised.
    mngr.close()
    shutil.rmtree(ckpt_dir, ignore_errors=True)

view raw JSON →