SageMaker MLflow Plugin

0.2.0 · active · verified Fri Apr 10

sagemaker-mlflow is an AWS plugin for MLflow that lets the standard MLflow client work with Amazon SageMaker's managed MLflow Tracking Servers, so experiment tracking and artifact storage run on SageMaker-managed infrastructure instead of a self-hosted server. The plugin signs the client's requests with your AWS credentials. The current version is 0.2.0; new releases ship as features and bug fixes are added.
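A SageMaker-managed tracking server is identified by an ARN of the form arn:aws:sagemaker:<region>:<account-id>:mlflow-tracking-server/<name>, and with the plugin installed that ARN serves as the MLflow tracking URI. The sketch below splits such an ARN into its parts, which can help confirm you are pointing at the intended region and account before logging anything. The `TrackingServerArn` class and `parse_tracking_server_arn` function are hypothetical illustrations, not part of sagemaker-mlflow, and the example ARN is a placeholder.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrackingServerArn:
    """Components of a SageMaker MLflow tracking server ARN (hypothetical helper)."""

    partition: str
    region: str
    account_id: str
    server_name: str


def parse_tracking_server_arn(arn: str) -> TrackingServerArn:
    """Split an ARN shaped like
    arn:<partition>:sagemaker:<region>:<account-id>:mlflow-tracking-server/<name>.
    Raises ValueError for anything else."""
    # At most 6 fields: the resource part may not contain ':' but we cap the split anyway.
    parts = arn.split(":", 5)
    if len(parts) != 6 or parts[0] != "arn" or parts[2] != "sagemaker":
        raise ValueError(f"not a SageMaker ARN: {arn!r}")
    _, partition, _, region, account_id, resource = parts
    if not region or not (len(account_id) == 12 and account_id.isdigit()):
        raise ValueError(f"missing region or malformed account ID: {arn!r}")
    # The resource looks like "mlflow-tracking-server/<name>".
    resource_type, sep, server_name = resource.partition("/")
    if resource_type != "mlflow-tracking-server" or not sep or not server_name:
        raise ValueError(f"not an MLflow tracking server ARN: {arn!r}")
    return TrackingServerArn(partition, region, account_id, server_name)


if __name__ == "__main__":
    # Placeholder ARN; substitute your own tracking server's ARN.
    example = "arn:aws:sagemaker:us-east-1:123456789012:mlflow-tracking-server/my-tracking-server"
    print(parse_tracking_server_arn(example))
```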

Install

pip install mlflow sagemaker-mlflow

Quickstart

This example configures MLflow to log to a SageMaker MLflow Tracking Server and records a simple run. Before running it, make sure a tracking server is already provisioned and that your AWS credentials or IAM role have permission to access it.

import mlflow

# Requires both packages: pip install mlflow sagemaker-mlflow
# sagemaker-mlflow registers itself as an MLflow plugin, so MLflow loads it
# automatically; no explicit "import sagemaker_mlflow" is needed.

# Placeholder: replace with the ARN of your own SageMaker MLflow Tracking Server.
tracking_server_arn = "arn:aws:sagemaker:us-east-1:123456789012:mlflow-tracking-server/my-tracking-server"

try:
    # With the plugin installed, the tracking server ARN is used directly as the
    # tracking URI, and requests are signed with your AWS credentials. The same
    # code works in SageMaker Studio, where credentials come from the execution role.
    mlflow.set_tracking_uri(tracking_server_arn)
    print(f"MLflow Tracking URI set to: {mlflow.get_tracking_uri()}")

    with mlflow.start_run() as run:
        mlflow.log_param("alpha", 0.5)
        mlflow.log_metric("rmse", 0.75)
        print(f"Logged run with ID: {run.info.run_id}")

except Exception as e:
    print(f"Failed to set MLflow Tracking URI or start run: {e}")
    print("Ensure a SageMaker MLflow Tracking Server is deployed and your IAM role has permissions.")
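MLflow also reads its tracking URI from the MLFLOW_TRACKING_URI environment variable when mlflow.set_tracking_uri() has not been called, so a tracking server ARN (format arn:aws:sagemaker:<region>:<account-id>:mlflow-tracking-server/<name>) can be supplied from outside the script. A minimal sketch, assuming a hypothetical helper `export_tracking_uri` and a deliberately loose ARN regex; neither comes from sagemaker-mlflow.

```python
import os
import re
from collections.abc import MutableMapping

# Loose shape check for a SageMaker MLflow tracking server ARN. This pattern is an
# assumption for illustration, not validation logic taken from sagemaker-mlflow.
_TRACKING_SERVER_ARN = re.compile(
    r"^arn:aws[a-z-]*:sagemaker:[a-z0-9-]+:\d{12}:mlflow-tracking-server/[^/\s]+$"
)


def export_tracking_uri(arn: str, environ: MutableMapping[str, str] = os.environ) -> None:
    """Store the ARN in MLFLOW_TRACKING_URI so MLflow picks it up without an
    explicit mlflow.set_tracking_uri() call (hypothetical helper)."""
    if not _TRACKING_SERVER_ARN.match(arn):
        raise ValueError(f"unexpected tracking server ARN: {arn!r}")
    environ["MLFLOW_TRACKING_URI"] = arn


if __name__ == "__main__":
    # Placeholder ARN; set this before MLflow resolves its tracking URI.
    export_tracking_uri(
        "arn:aws:sagemaker:us-east-1:123456789012:mlflow-tracking-server/my-tracking-server"
    )
    print(os.environ["MLFLOW_TRACKING_URI"])
```

In a shell, `export MLFLOW_TRACKING_URI=<your-tracking-server-arn>` has the same effect. An explicit mlflow.set_tracking_uri() call made earlier in the same process takes precedence over the variable.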
