Metaflow Checkpoint

0.2.10 · active · verified Thu Apr 16

metaflow-checkpoint is an experimental extension for Metaflow that adds in-task checkpointing. It lets users periodically save the progress of long-running Metaflow steps, such as machine learning model training, so that a failed task can recover without losing significant work. The library, currently at version 0.2.10, is released as an independent extension to core Metaflow.

Quickstart

This quickstart shows a Metaflow flow that uses the `@checkpoint` decorator. The `start` step simulates a long-running, flaky process that increments a counter. After each increment it writes the counter to a file in `current.checkpoint.directory` and calls `current.checkpoint.save()`. On a restart (triggered by `@retry` or by the `resume` command), it checks `current.checkpoint.is_loaded` and, if a checkpoint was loaded, reads the last saved counter back from `current.checkpoint.directory`. Setting `load_policy='eager'` lets checkpoints be reused across different runs, which helps during iterative development. Run the flow with `python your_flow.py run`, and after an interruption try `python your_flow.py resume start`.

import os
import random
from metaflow import FlowSpec, step, current, checkpoint, retry

class CheckpointCounterFlow(FlowSpec):
    @retry(times=2, minutes_between_retries=1)
    @checkpoint(load_policy='eager') # Use 'eager' for development across runs
    @step
    def start(self):
        self.counter = 0
        counter_path = os.path.join(current.checkpoint.directory, 'counter')
        if current.checkpoint.is_loaded:
            # A previous attempt saved a checkpoint; restore the counter from it.
            with open(counter_path, 'r') as f:
                self.counter = int(f.read())
            print(f"Resuming from checkpoint. Counter is {self.counter}")
        else:
            print("Starting from scratch.")

        # Continue from the restored counter so a retry does not redo finished work.
        for i in range(self.counter, 5):
            self.counter += 1
            print(f"Processing iteration {i+1}, counter is {self.counter}")
            # Write progress into the checkpoint directory, then persist it.
            with open(counter_path, 'w') as f:
                f.write(str(self.counter))
            current.checkpoint.save()

            # Simulate a flaky operation
            if random.random() < 0.3:
                raise Exception("Simulated failure!")

        self.next(self.end)

    @step
    def end(self):
        print(f"Flow finished. Final counter value: {self.counter}")

if __name__ == '__main__':
    CheckpointCounterFlow()
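
The save-then-restore pattern used in the flow can also be tried without Metaflow installed. The sketch below is a plain-Python illustration of the same idea, not the library's implementation: `run_with_checkpoints` and its `fail_at` parameter are hypothetical names introduced here, and a temporary directory stands in for `current.checkpoint.directory`.

```python
import os
import tempfile


def run_with_checkpoints(checkpoint_dir, total_steps, fail_at=None):
    """Count up to total_steps, persisting progress after every step.

    fail_at is a hypothetical knob: the function raises right after that
    step has been saved, standing in for a crash or preemption.
    """
    path = os.path.join(checkpoint_dir, "counter")

    # Restore the last saved value if a checkpoint file exists.
    counter = 0
    if os.path.exists(path):
        with open(path) as f:
            counter = int(f.read())

    # Continue from the restored value rather than from zero.
    for step_index in range(counter, total_steps):
        counter = step_index + 1
        with open(path, "w") as f:
            f.write(str(counter))
        if fail_at is not None and counter == fail_at:
            raise RuntimeError(f"Simulated failure after step {counter}")

    return counter


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as workdir:
        try:
            run_with_checkpoints(workdir, total_steps=5, fail_at=3)
        except RuntimeError as err:
            print(f"First attempt stopped: {err}")
        # The second attempt starts at step 4 because step 3 was saved.
        print(run_with_checkpoints(workdir, total_steps=5))
```

The second call only performs the two remaining steps, which is the behavior the `@checkpoint` and `@retry` pair aims for in the flow above.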
