Metaflow Torchrun Integration

0.2.1 · active · verified Fri Apr 17

Metaflow-torchrun is a Python library that provides a `@torchrun` decorator for running distributed PyTorch training inside Metaflow steps. It abstracts away the complexity of launching and managing `torchrun` processes, so distributed ML workloads integrate seamlessly into Metaflow flows. The current version is 0.2.1, and the library has been updated regularly since its initial release.

Common errors

Warnings

Install
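Assuming the package is published on PyPI under the same name used throughout this page, a standard pip install:

```shell
pip install metaflow-torchrun
# metaflow and torch are also required for the quickstart
pip install metaflow torch
```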

Imports
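The imports used by the quickstart below:

```python
import os
import logging

from metaflow import FlowSpec, step, current
from metaflow_torchrun import torchrun
```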

Quickstart

This example demonstrates a basic Metaflow flow using the `@torchrun` decorator. Save this code as a Python file (e.g., `my_flow.py`) and run it from your terminal using `python my_flow.py run`. It will launch 2 parallel processes on the local machine within the `train_distributed` step, each printing its rank information. Ensure `metaflow-torchrun`, `metaflow`, and `torch` are installed in your environment.

import os
import logging
from metaflow import FlowSpec, step, current
from metaflow_torchrun import torchrun

# Set up basic logging to see output from torchrun
logging.basicConfig(level=logging.INFO)

class MyDistributedFlow(FlowSpec):

    @step
    def start(self):
        print(f"Starting flow {current.flow_name}...")
        self.next(self.train_distributed)

    # Decorate a step with @torchrun to enable distributed execution
    @torchrun(nproc_per_node=2) # Use 2 processes for illustration on a single node
    @step
    def train_distributed(self):
        # Inside a torchrun decorated step, the script is executed by multiple processes.
        # Each process will have environment variables like LOCAL_RANK, RANK, WORLD_SIZE.
        local_rank = int(os.environ.get("LOCAL_RANK", -1))
        global_rank = int(os.environ.get("RANK", -1)) # RANK is the global rank
        world_size = int(os.environ.get("WORLD_SIZE", -1))

        print(f"Hello from process {global_rank}/{world_size} (local rank {local_rank})\n" +
              f" in Metaflow run {current.run_id} step {current.step_name}.")

        # In a real scenario, you'd initialize a process group and run DDP here:
        # import torch
        # import torch.distributed as dist
        # if local_rank != -1:
        #     backend = "nccl" if torch.cuda.is_available() else "gloo"
        #     dist.init_process_group(backend, rank=global_rank, world_size=world_size)
        #     ... your PyTorch model training code ...
        #     dist.destroy_process_group()

        self.next(self.end)

    @step
    def end(self):
        print("Distributed training flow completed!")

if __name__ == "__main__":
    # To run this flow:
    # 1. Save it as a Python file, e.g., `my_flow.py`
    # 2. Run from your terminal: `python my_flow.py run`
    # Ensure `metaflow-torchrun`, `metaflow`, and `torch` are installed.
    MyDistributedFlow()
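The `RANK` and `WORLD_SIZE` variables shown above are typically used to split work across processes. A minimal sketch of rank-based data sharding, independent of Metaflow and torch (the `shard` helper name is illustrative, not part of the library):

```python
def shard(items, rank, world_size):
    """Return the slice of `items` that process `rank` should handle.

    Uses a strided split: rank r takes indices r, r + world_size, r + 2*world_size, ...
    This is similar in spirit to how PyTorch's DistributedSampler
    partitions a dataset across processes.
    """
    return items[rank::world_size]

# Simulate the values torchrun would expose via RANK and WORLD_SIZE.
data = list(range(10))
world_size = 2
shards = [shard(data, rank, world_size) for rank in range(world_size)]
print(shards)  # two disjoint shards that together cover all of `data`
```

In a real step, each process would read its rank from the environment (as the quickstart does) and operate only on its own shard.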
