Orbax Checkpoint
Orbax Checkpoint is an asynchronous, fault-tolerant checkpointing library designed primarily for JAX and Flax models. It saves and restores large-scale model states and arbitrary data structures, with support for distributed environments and custom serialization. The current version is 0.11.34; releases are frequent and track JAX/Flax developments.
Common errors
-
ModuleNotFoundError: No module named 'orbax.checkpoint'
cause The `orbax-checkpoint` library or its dependencies are not installed, or the Python environment running the code cannot see the installed package.
fix Ensure `orbax-checkpoint` is installed in your active Python environment. If using a virtual environment, activate it before installing.
```bash
pip install orbax-checkpoint
# Or, for a specific version:
pip install orbax-checkpoint==0.11.34
```
-
WARNING:absl:Item "state" was found in the checkpoint, but could not be restored. Please provide a `CheckpointHandlerRegistry`, or call `restore` with an appropriate `CheckpointArgs` subclass.
cause This warning, which can lead to restoration issues, indicates that the `CheckpointManager` is trying to restore an item without knowing which handler or arguments to use. It often follows API changes where `item_names` and `args` are now preferred over the deprecated `items` argument, or a missing `CheckpointHandlerRegistry` configuration.
fix When using `CheckpointManager.restore()`, explicitly provide `CheckpointArgs` (e.g., `ocp.args.StandardRestore` or `ocp.args.Composite`) or configure a `handler_registry` during `CheckpointManager` initialization.
```python
import orbax.checkpoint as ocp

# Assuming 'directory' is your checkpoint path and 'abstract_state' defines the expected structure
mngr = ocp.CheckpointManager(directory)
restored_state = mngr.restore(
    mngr.latest_step(),
    args=ocp.args.StandardRestore(abstract_state)  # Use appropriate CheckpointArgs
)
# If the checkpoint was saved with named items, wrap the args per item:
#   args=ocp.args.Composite(state=ocp.args.StandardRestore(abstract_state))

# Or, configure a handler registry at initialization (for multiple items).
# The registry maps an item name and an args type to a handler; check the
# `add` signature in your installed version before relying on this form.
# registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
# registry.add('state', ocp.args.StandardSave)
# registry.add('state', ocp.args.StandardRestore)
# mngr = ocp.CheckpointManager(directory, handler_registry=registry)
# restored_state = mngr.restore(mngr.latest_step())
```
-
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices()
cause This error typically occurs during restoration when the `Sharding` information saved in the checkpoint names a device or sharding configuration (e.g., a specific CUDA device) that is unavailable or does not match `jax.local_devices()` in the current environment. This can happen when moving checkpoints between hardware setups (e.g., from GPU to CPU, or between different device counts).
fix Ensure your JAX environment is initialized correctly for the target devices. If restoring to a different topology, you might need to specify the `sharding` explicitly in `ArrayRestoreArgs` or `ArrayOptions.Loading` (for v1 API) to remap sharding. Updating `orbax-checkpoint` to the latest version can also resolve some cases.
```python
import jax
import numpy as np
import orbax.checkpoint as ocp

# Ensure JAX devices are correctly initialized
# jax.distributed.initialize()  # For multi-host setups

# When restoring, you might need to re-specify the target sharding.
# NamedSharding takes a Mesh, so build one over the devices available now.
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(-1, 1), ('dp', 'mp'))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dp', 'mp'))  # Example for a new sharding

# 'target_state_with_new_sharding' is a placeholder for a pytree whose arrays
# already carry the new sharding.
abstract_state = jax.tree_util.tree_map(
    ocp.utils.to_shape_dtype_struct, target_state_with_new_sharding
)

# Example with Checkpointer. Recent StandardCheckpointer releases take the
# target directly; 'path/to/checkpoint' is a placeholder (Orbax expects an
# absolute path).
ckptr = ocp.StandardCheckpointer()
restored = ckptr.restore(
    'path/to/checkpoint',
    abstract_state  # Carries the target sharding
)
```
-
Requested shape: (32,) is not compatible with the stored shape: (16,). Truncating/padding is disabled.
cause This error occurs when restoring an array from a checkpoint into a target array (or abstract shape) with a different shape. Orbax's default strict restoration policy prevents automatic resizing.
fix To allow restoration with different shapes, enable padding or truncation by setting `strict=False` in `StandardRestore` or `ArrayRestoreArgs` (for v0 API), or `enable_padding_and_truncation=True` in `ArrayOptions.Loading` (for v1 API), for the specific array(s) or generally.
```python
import jax
import orbax.checkpoint as ocp

# Example of how to enable padding/truncation during restore
path = '/tmp/my-checkpoint'
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

# Build an abstract state with the *desired* target shape, e.g. (32,).
# 'original_state' is a placeholder for a pytree with the saved structure.
restored_abstract_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct((32,), x.dtype),  # Example: new shape (32,)
    jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, original_state)
)

# Restore with padding/truncation enabled (v0 API)
restored = ckptr.restore(
    path,
    args=ocp.args.StandardRestore(restored_abstract_state, strict=False)
)

# For per-array control, `ocp.ArrayRestoreArgs(strict=False)` can be supplied
# through the `restore_args` of `ocp.args.PyTreeRestore` with a PyTree handler.
```
Warnings
- breaking The serialization API has changed across minor versions, notably around `ocp.args.StandardSave` and `ocp.args.StandardRestore`; data passed to `save` and `restore` must be wrapped explicitly in these argument classes.
- gotcha CheckpointManager saves are asynchronous by default. If the program exits before `.wait_until_finished()` is called, checkpoints can be left incomplete or corrupted.
- gotcha Manual modification of checkpoint directories or subfolders can interfere with CheckpointManager's internal state and cleanup logic (`max_to_keep`).
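One way to make sure pending saves are flushed even when training code raises is to wrap the manager in a small context manager. The sketch below is illustrative only: `finalized` is a hypothetical helper, not part of Orbax, and `_FakeManager` is a stand-in so the call order can be shown without Orbax installed. It assumes the manager exposes `wait_until_finished()` and `close()`, as the Quickstart's `CheckpointManager` does.

```python
from contextlib import contextmanager


@contextmanager
def finalized(manager):
    """Yield `manager`, then wait for pending saves and close it.

    Hypothetical helper: assumes `manager` has `wait_until_finished()`
    and `close()` methods.
    """
    try:
        yield manager
    finally:
        # Runs on normal exit and on exceptions alike.
        manager.wait_until_finished()
        manager.close()


class _FakeManager:
    """Stand-in that records calls; not an Orbax class."""

    def __init__(self):
        self.calls = []

    def save(self, step):
        self.calls.append(f"save:{step}")

    def wait_until_finished(self):
        self.calls.append("wait")

    def close(self):
        self.calls.append("close")


fake = _FakeManager()
try:
    with finalized(fake) as mngr:
        mngr.save(0)
        raise RuntimeError("simulated crash mid-training")
except RuntimeError:
    pass

print(fake.calls)  # ['save:0', 'wait', 'close']
```

With a real manager the same pattern would read `with finalized(ocp.CheckpointManager(ckpt_dir, options=options)) as mngr:`, so the wait happens before the process can exit.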
Install
-
pip install orbax-checkpoint jax flax
-
pip install orbax-checkpoint
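After installing, it can help to confirm which versions the active environment actually sees, since a missing or mismatched package is behind the `ModuleNotFoundError` listed under Common errors. The following is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper name.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Report what the current interpreter can see for each package.
for name in ("orbax-checkpoint", "jax", "flax"):
    version = installed_version(name)
    print(f"{name}: {version if version is not None else 'not installed'}")
```

Running this with the same interpreter that runs the training script (for example inside the activated virtual environment) shows whether the install landed where expected.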
Imports
- orbax.checkpoint
import orbax.checkpoint as ocp
- CheckpointManager
from orbax.checkpoint import CheckpointManager
- CheckpointManagerOptions
from orbax.checkpoint import CheckpointManagerOptions
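- StandardSave
from orbax.checkpoint.args import StandardSave  # documented location
from orbax.checkpoint import StandardSave  # top-level path; not guaranteed to be exported, prefer the line above
- StandardRestore
from orbax.checkpoint.args import StandardRestore  # documented location
from orbax.checkpoint import StandardRestore  # top-level path; not guaranteed to be exported, prefer the line above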
Quickstart
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
import os
import shutil
# Define a temporary checkpoint directory
ckpt_dir = '/tmp/my_orbax_checkpoint_example'
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)
os.makedirs(ckpt_dir, exist_ok=True)
# 1. Create a CheckpointManager
options = ocp.CheckpointManagerOptions(
    save_interval_steps=1,
    max_to_keep=3,
    keep_time_interval=None,  # a datetime.timedelta when set
)
mngr = ocp.CheckpointManager(ckpt_dir, options=options)
# 2. Prepare some data to save
step = 0
data_to_save = {'params': jnp.array([1.0, 2.0, 3.0])}
print(f"Saving data at step {step}: {data_to_save['params']}")
# 3. Save the checkpoint
# Wrap the data in StandardSave so the serialization format is explicit
mngr.save(step, args=ocp.args.StandardSave(data_to_save))
mngr.wait_until_finished() # Ensure save completes
# Simulate more steps and saves
step = 1
data_to_save = {'params': jnp.array([4.0, 5.0, 6.0])}
print(f"Saving data at step {step}: {data_to_save['params']}")
mngr.save(step, args=ocp.args.StandardSave(data_to_save))
mngr.wait_until_finished()
# 4. Restore the latest checkpoint
latest_step = mngr.latest_step()
if latest_step is not None:
    print(f"\nRestoring data from latest step: {latest_step}")
    # Provide a template for StandardRestore, even if just the expected structure
    restored_data = mngr.restore(latest_step, args=ocp.args.StandardRestore(data_to_save))
    print(f"Restored data: {restored_data['params']}")
else:
    print("No checkpoint found to restore.")
# 5. Close the manager
mngr.close()
# Clean up
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)
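Step 4 of the Quickstart passes the concrete `data_to_save` as the restore template. When the real arrays are not at hand, for example in a fresh process, a template carrying only shapes and dtypes can be built instead. The sketch below uses plain JAX and assumes, as the Quickstart comment suggests, that `StandardRestore` only needs the expected structure; the variable name `abstract_target` is illustrative.

```python
import jax
import jax.numpy as jnp

# Same structure as the Quickstart's saved data.
data_to_save = {'params': jnp.array([1.0, 2.0, 3.0])}

# Replace every array leaf with a shape/dtype placeholder (no device memory
# is allocated for the values).
abstract_target = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    data_to_save,
)

print(abstract_target['params'].shape)  # (3,)
```

The resulting `abstract_target` could then take the place of `data_to_save` in the `StandardRestore(...)` call of step 4.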