SageMaker MLflow Plugin
sagemaker-mlflow is an AWS plugin for MLflow that lets the MLflow client log to a SageMaker managed MLflow Tracking Server. You pass the tracking server's ARN as the MLflow tracking URI, and the plugin authenticates MLflow requests with your AWS credentials, so experiment tracking and artifact storage run on SageMaker's managed infrastructure. The current version is 0.2.0. Releases follow no fixed schedule: they ship as new features and bug fixes land, typically driven by community contributions and AWS service updates.
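Because the tracking URI is the tracking server's ARN, it can help to see what that string looks like. The sketch below assembles one from its parts. The helper name `tracking_server_arn` is hypothetical, and the layout `arn:aws:sagemaker:<region>:<account-id>:mlflow-tracking-server/<name>` is an assumption based on the usual SageMaker ARN format in the standard `aws` partition; copy the real ARN from the SageMaker console or API when in doubt.

```python
# Hypothetical helper (not part of sagemaker-mlflow): builds the ARN that is
# passed to mlflow.set_tracking_uri(). Assumes the standard "aws" partition
# and the mlflow-tracking-server/<name> resource path.
def tracking_server_arn(region: str, account_id: str, server_name: str) -> str:
    """Return the assumed ARN of a SageMaker MLflow Tracking Server."""
    # AWS account IDs are always 12 digits; catch typos early.
    if not (account_id.isdigit() and len(account_id) == 12):
        raise ValueError(f"AWS account IDs are 12 digits, got {account_id!r}")
    return (
        f"arn:aws:sagemaker:{region}:{account_id}:"
        f"mlflow-tracking-server/{server_name}"
    )


if __name__ == "__main__":
    arn = tracking_server_arn("us-east-1", "111122223333", "my-tracking-server")
    print(arn)
    # With mlflow and sagemaker-mlflow installed, this value would be passed
    # to mlflow.set_tracking_uri(arn).
```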
Warnings
- gotcha The sagemaker-mlflow plugin connects to an existing SageMaker MLflow Tracking Server. It does not provision or deploy this server for you. Users must first deploy an MLflow Tracking Server in SageMaker (e.g., via SageMaker Studio applications) before using this plugin.
- gotcha Proper IAM permissions are crucial. The AWS identity (user or role) executing the MLflow code must have permissions to list and connect to SageMaker MLflow Tracking Servers, as well as necessary S3 permissions for artifact storage. Missing permissions often result in connection errors.
- gotcha Cross-account access to a SageMaker MLflow Tracking Server is only supported from `sagemaker-mlflow` version 0.2.0 onwards. If you are using an older version, the tracking server and your MLflow client must reside in the same AWS account.
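The account and version constraints in these warnings can be checked before any MLflow call. The following is a minimal sketch, not a feature of the plugin: the function names are hypothetical, it assumes the account ID is the fifth colon-separated field of the tracking server ARN (true of standard ARNs), and it takes the 0.2.0 cross-account threshold from the note above.

```python
from __future__ import annotations

import re

# Threshold taken from the cross-account note; not read from the plugin itself.
CROSS_ACCOUNT_MIN_VERSION = (0, 2, 0)


def arn_account_id(arn: str) -> str:
    """Return the account-ID field (5th colon-separated field) of an ARN."""
    fields = arn.split(":")
    if len(fields) < 6 or fields[0] != "arn":
        raise ValueError(f"not an ARN: {arn!r}")
    return fields[4]


def version_tuple(version: str) -> tuple[int, ...]:
    """Convert '0.2' or '0.2.0rc1' to a 3-tuple such as (0, 2, 0)."""
    nums = [int(x) for x in re.findall(r"\d+", version)[:3]]
    nums += [0] * (3 - len(nums))  # pad so '0.2' compares equal to '0.2.0'
    return tuple(nums)


def check_tracking_target(arn: str, caller_account_id: str, plugin_version: str) -> list[str]:
    """Return a list of problems; an empty list means no known blocker."""
    problems = []
    server_account = arn_account_id(arn)
    if (server_account != caller_account_id
            and version_tuple(plugin_version) < CROSS_ACCOUNT_MIN_VERSION):
        problems.append(
            f"tracking server is in account {server_account} but the caller is in "
            f"{caller_account_id}; cross-account access needs sagemaker-mlflow >= 0.2.0 "
            f"(installed: {plugin_version})"
        )
    return problems


if __name__ == "__main__":
    arn = "arn:aws:sagemaker:us-east-1:111122223333:mlflow-tracking-server/demo"
    for problem in check_tracking_target(arn, "444455556666", "0.1.0"):
        print(problem)
```

In a real environment the caller's account ID could come from `boto3.client("sts").get_caller_identity()["Account"]`. It is passed in as a plain string here so the check runs without AWS credentials. This sketch does not cover IAM permissions, which only AWS can evaluate.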
Install
pip install sagemaker-mlflow
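Since some behavior (such as cross-account access) depends on the plugin version, it can be useful to confirm what the install step actually put in the environment. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_version` is illustrative.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_version(distribution: str = "sagemaker-mlflow") -> str | None:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Report both the plugin and MLflow itself.
    for dist in ("sagemaker-mlflow", "mlflow"):
        found = installed_version(dist)
        print(f"{dist}: {found if found else 'not installed'}")
```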
Imports
- sagemaker_mlflow
import sagemaker_mlflow
Quickstart
import os

import mlflow
import sagemaker_mlflow  # noqa: F401  (fails fast if the plugin is not installed)

# Requires mlflow and sagemaker-mlflow to be installed, plus a SageMaker MLflow
# Tracking Server that is already deployed. Replace the placeholder with its ARN.
TRACKING_SERVER_ARN = "arn:aws:sagemaker:<region>:<account-id>:mlflow-tracking-server/<server-name>"

# Outside SageMaker Studio you may need to set the AWS region explicitly.
# os.environ["AWS_REGION"] = "us-east-1"

# Option 1: explicitly set the tracking URI to the tracking server ARN.
try:
    mlflow.set_tracking_uri(TRACKING_SERVER_ARN)
    print(f"MLflow Tracking URI set to: {TRACKING_SERVER_ARN}")
    with mlflow.start_run() as run:
        mlflow.log_param("alpha", 0.5)
        mlflow.log_metric("rmse", 0.75)
        print(f"Logged run with ID: {run.info.run_id}")
except Exception as e:
    print(f"Failed to set MLflow Tracking URI or start run: {e}")
    print("Ensure a SageMaker MLflow Tracking Server is deployed and your IAM role has permissions.")

# Option 2: set the MLFLOW_TRACKING_URI environment variable (for example in the
# shell or job configuration) so code that never calls mlflow.set_tracking_uri()
# still logs to the server. An explicit set_tracking_uri() call takes precedence.
# os.environ["MLFLOW_TRACKING_URI"] = TRACKING_SERVER_ARN
# with mlflow.start_run() as run:
#     mlflow.log_param("beta", 0.1)
#     mlflow.log_metric("accuracy", 0.99)
#     print(f"Logged another run with ID: {run.info.run_id}")