Metaflow Torchrun Integration
Metaflow-torchrun is a Python library that provides a `@torchrun` decorator to enable distributed PyTorch training within Metaflow steps. It abstracts away the complexities of launching and managing `torchrun` processes, allowing users to integrate distributed ML workflows seamlessly into their Metaflow flows. The current version is 0.2.1, with relatively frequent updates since its initial release.
Common errors
- ModuleNotFoundError: No module named 'torchrun'
  Cause: `torch` (which bundles `torchrun`) is not installed or not accessible in the Python environment where the Metaflow step runs.
  Fix: Install the PyTorch library: `pip install torch` (or the appropriate command for your system/CUDA version, e.g., `pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`).
- TypeError: torchrun() got an unexpected keyword argument 'my_arg'
  Cause: An unknown or misspelled argument was passed directly to the `@torchrun` decorator, or arguments intended for the PyTorch script were passed incorrectly.
  Fix: Use only valid decorator arguments such as `nproc_per_node`, `args`, and `module`. For arguments to your PyTorch script, use the `args` parameter: `@torchrun(args=['--my-script-arg', 'value'])`.
- MetaflowException: The @torchrun decorator is only supported when running a flow from a Python file, not interactively.
  Cause: The flow was run from an interactive environment (e.g., a Jupyter notebook or IPython console), which `@torchrun` does not support.
  Fix: Save your Metaflow flow as a Python file (e.g., `my_flow.py`) and execute it from the command line: `python my_flow.py run`.
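For the first error above, a cheap preflight check can fail fast with a readable message instead of a ModuleNotFoundError deep inside a `@torchrun` step. This is a stdlib-only sketch; the helper name `torch_is_available` is my own, not part of metaflow-torchrun:

```python
import importlib.util

def torch_is_available() -> bool:
    """Return True if the `torch` package can be imported in this environment."""
    return importlib.util.find_spec("torch") is not None

# Example: check before the flow launches any distributed work.
if not torch_is_available():
    print("torch is not installed; run `pip install torch` first.")
```

Running the same check inside the step (rather than on the launching machine) is what matters when steps execute in a different environment, e.g. under `@conda` or a remote compute decorator.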
Warnings
- The `@torchrun` decorator is incompatible with interactive notebooks (e.g., Jupyter); flows must be run from a standalone Python file.
- Ensure `torch` is installed and `torchrun` is available in the Python environment where the Metaflow step executes.
- Arguments intended for your PyTorch script, passed via the `args` parameter of the `@torchrun` decorator, must be a list of strings.
- The current implementation of `@torchrun` distributes processes across multiple GPUs/CPUs *on a single node*, not across multiple machines.
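The list-of-strings warning is easy to trip over when argument values are numeric. A minimal sketch of normalizing a mixed list before handing it to `args=`; the helper name `as_torchrun_args` is hypothetical, not a library function:

```python
def as_torchrun_args(args) -> list:
    """Coerce every element to str so the list is safe to pass to
    @torchrun(args=...), which expects a list of strings."""
    return [str(a) for a in args]

# Numbers become their string representations:
script_args = as_torchrun_args(["--epochs", 10, "--lr", 3e-4])
# -> ["--epochs", "10", "--lr", "0.0003"]
```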
Install
- pip install metaflow-torchrun
- pip install metaflow torch
Imports
- torchrun: `from metaflow_torchrun import torchrun`
Quickstart
import os
import logging

from metaflow import FlowSpec, step, current
from metaflow_torchrun import torchrun

# Set up basic logging to see output from torchrun
logging.basicConfig(level=logging.INFO)


class MyDistributedFlow(FlowSpec):

    @step
    def start(self):
        print(f"Starting flow {current.flow_name}...")
        self.next(self.train_distributed)

    # Decorate a step with @torchrun to enable distributed execution
    @torchrun(nproc_per_node=2)  # Use 2 processes for illustration on a single node
    @step
    def train_distributed(self):
        # Inside a torchrun-decorated step, the code is executed by multiple
        # processes. Each process sees environment variables such as
        # LOCAL_RANK, RANK, and WORLD_SIZE.
        local_rank = int(os.environ.get("LOCAL_RANK", -1))
        global_rank = int(os.environ.get("RANK", -1))  # RANK is the global rank
        world_size = int(os.environ.get("WORLD_SIZE", -1))
        print(
            f"Hello from process {global_rank}/{world_size} (local rank {local_rank})\n"
            f"  in Metaflow run {current.run_id} step {current.step_name}."
        )
        # In a real scenario, you'd initialize a process group and run DDP here:
        # import torch
        # import torch.distributed as dist
        # if local_rank != -1:
        #     dist.init_process_group(
        #         "nccl" if torch.cuda.is_available() else "gloo",
        #         rank=global_rank,
        #         world_size=world_size,
        #     )
        #     ...  # your PyTorch model training code
        #     dist.destroy_process_group()
        self.next(self.end)

    @step
    def end(self):
        print("Distributed training flow completed!")


if __name__ == "__main__":
    # To run this flow:
    #   1. Save it as a Python file, e.g. `my_flow.py`
    #   2. Run from your terminal: `python my_flow.py run`
    # Ensure `metaflow-torchrun`, `metaflow`, and `torch` are installed.
    MyDistributedFlow()
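The rank-reading pattern in the step body above can be factored into a small, stdlib-only helper that is easy to unit-test outside of torchrun. `distributed_context` is a hypothetical name, not part of metaflow-torchrun:

```python
import os

def distributed_context(environ=os.environ):
    """Read the rank-related environment variables that torchrun sets for
    each worker process. Returns (local_rank, global_rank, world_size);
    every value is -1 when running outside torchrun."""
    local_rank = int(environ.get("LOCAL_RANK", -1))
    global_rank = int(environ.get("RANK", -1))
    world_size = int(environ.get("WORLD_SIZE", -1))
    return local_rank, global_rank, world_size

# Outside torchrun, all three default to -1:
print(distributed_context({}))  # -> (-1, -1, -1)

# Under torchrun with two processes, worker 1 would see something like:
print(distributed_context({"LOCAL_RANK": "1", "RANK": "1", "WORLD_SIZE": "2"}))
# -> (1, 1, 2)
```

Passing the environment as a parameter (defaulting to `os.environ`) keeps the helper testable with a plain dict, which is useful since these variables only exist inside torchrun-launched processes.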