ML Diagnostics Python SDK
The `google-cloud-mldiagnostics` library is the Python SDK for Google Cloud's ML Diagnostics platform. It integrates with machine learning workloads to collect and manage workload metrics, configurations, and profiles, and it supports both programmatic and on-demand profile capture. With it, users can create and monitor machine learning runs, deploy managed XProf resources for performance profiling, and visualize workload metrics, configurations, and profiles on Google Cloud. The library is actively maintained and is updated as new ML Diagnostics features ship in Google Cloud.
Common errors
- google.api_core.exceptions.PermissionDenied: 403 Permission denied to access project...
  cause: The Python environment running the code has no valid Google Cloud credentials, or the authenticated principal lacks permission to create or manage ML Diagnostics resources in the specified project.
  fix: Point the `GOOGLE_APPLICATION_CREDENTIALS` environment variable at a valid service account key JSON file with appropriate roles for the project (e.g., 'ML Diagnostics Editor' or 'Owner'). Alternatively, run `gcloud auth application-default login` and make sure the authenticated user has sufficient permissions.
- ModuleNotFoundError: No module named 'google_cloud_mldiagnostics'
  cause: The `google-cloud-mldiagnostics` library is not installed in the Python environment.
  fix: Install it with `pip install google-cloud-mldiagnostics`.
- TypeError: create_run() missing 1 required positional argument: 'project_id'
  cause: `machinelearning_run.create_run()` was called without the `project_id` argument. All ML Diagnostics operations require a target Google Cloud project.
  fix: Always pass `project_id` to `create_run()`. You can read it from an environment variable such as `GCP_PROJECT_ID` or give it explicitly, e.g., `machinelearning_run.create_run(project_id='your-project-id', ...)`.
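All three errors above can usually be caught before the first API call. The following stdlib-only sketch runs such a pre-flight check. The function name `preflight_problems` is hypothetical. The check for the default `gcloud` credentials file assumes the Linux/macOS location. A clean result does not prove that the principal holds the required IAM roles. On GKE or Compute Engine, credentials may come from the metadata server instead of a file, so the credentials check can raise a false alarm there.

```python
import importlib.util
import os
from pathlib import Path

# Assumed default location of the file written by
# `gcloud auth application-default login` on Linux and macOS.
ADC_DEFAULT_PATH = Path.home() / ".config" / "gcloud" / "application_default_credentials.json"


def preflight_problems(env=None):
    """Return human-readable problems matching the common errors listed above.

    An empty list only means these local checks passed; it says nothing about
    IAM roles. Credentials served by a metadata server (GKE, Compute Engine)
    are not detected here.
    """
    env = os.environ if env is None else env
    problems = []

    # ModuleNotFoundError: is the SDK importable in this interpreter?
    if importlib.util.find_spec("google_cloud_mldiagnostics") is None:
        problems.append("SDK missing: run `pip install google-cloud-mldiagnostics`.")

    # PermissionDenied (credentials part): is there a credentials file on disk?
    key_file = env.get("GOOGLE_APPLICATION_CREDENTIALS")
    if key_file:
        if not Path(key_file).is_file():
            problems.append(f"GOOGLE_APPLICATION_CREDENTIALS points to a missing file: {key_file}")
    elif not ADC_DEFAULT_PATH.is_file():
        problems.append(
            "No credentials found: set GOOGLE_APPLICATION_CREDENTIALS or run "
            "`gcloud auth application-default login`."
        )

    # TypeError from create_run(): is a project ID available to pass in?
    if not env.get("GCP_PROJECT_ID"):
        problems.append("GCP_PROJECT_ID is not set; create_run() needs project_id.")

    return problems


if __name__ == "__main__":
    for problem in preflight_problems():
        print("preflight:", problem)
```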
Warnings
- breaking The generic `google-cloud` package is deprecated. Install product-specific packages such as `google-cloud-mldiagnostics` instead of the umbrella package so that you get current releases of each client library.
- gotcha The ML Diagnostics SDK for Python currently supports only JAX on TPUs officially. Other frameworks (e.g., TensorFlow, PyTorch) and other hardware (e.g., GPUs, CPUs) are not officially supported and may have limitations.
- gotcha To route SDK logs, metrics, and configuration information to Google Cloud Logging, you must explicitly install and configure the `google-cloud-logging` library in your application. Without this, SDK output will only go to standard Python logging, not Cloud Logging.
- gotcha Authentication to Google Cloud services is required. Incorrect or missing authentication credentials (e.g., `GOOGLE_APPLICATION_CREDENTIALS` not set, or `gcloud auth application-default login` not run) will lead to permission errors when the SDK attempts to interact with Google Cloud APIs.
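Official support is limited to JAX on TPUs, so a workload can check its backend at startup and log a warning when it runs elsewhere. This sketch is not part of the SDK. `check_platform` and `detect_jax_backend` are hypothetical helpers. The backend name comes from `jax.default_backend()`, which returns strings such as `'tpu'`, `'gpu'`, or `'cpu'`.

```python
import importlib.util

# Per the warning above: only JAX on TPUs is officially supported.
SUPPORTED_BACKENDS = {"tpu"}


def detect_jax_backend():
    """Return JAX's default backend name, or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so this module loads without JAX
    return jax.default_backend()


def check_platform(backend):
    """Return a warning string if the setup is outside official support, else None."""
    if backend is None:
        return "JAX is not installed; ML Diagnostics officially supports only JAX on TPUs."
    if backend not in SUPPORTED_BACKENDS:
        return f"JAX backend is '{backend}'; ML Diagnostics officially supports only TPUs."
    return None


if __name__ == "__main__":
    message = check_platform(detect_jax_backend())
    print(message or "Running JAX on TPU: officially supported.")
```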
Install
- pip install google-cloud-mldiagnostics
- pip install google-cloud-logging (optional; needed only to route SDK output to Cloud Logging)
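`google-cloud-logging` is a separate install. A script may therefore want to use Cloud Logging when it is available and fall back to standard Python logging otherwise. The following minimal sketch assumes the `google.cloud.logging.Client().setup_logging()` call used in the quickstart. The helper name `setup_logging` is hypothetical.

```python
import logging


def setup_logging():
    """Route logs to Cloud Logging if possible; otherwise use standard logging.

    Returns True if Cloud Logging was configured, False if it fell back.
    """
    try:
        import google.cloud.logging  # provided by `pip install google-cloud-logging`
    except ImportError:
        logging.basicConfig(level=logging.INFO)
        logging.warning("google-cloud-logging is not installed; using standard logging only.")
        return False
    try:
        client = google.cloud.logging.Client()
        client.setup_logging()  # attaches a Cloud Logging handler to the root logger
    except Exception as exc:  # e.g. missing credentials or no default project
        logging.basicConfig(level=logging.INFO)
        logging.warning("Cloud Logging setup failed (%s); using standard logging only.", exc)
        return False
    return True


if __name__ == "__main__":
    using_cloud = setup_logging()
    logging.info("Cloud Logging enabled: %s", using_cloud)
```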
Imports
- machinelearning_run
from google_cloud_mldiagnostics import machinelearning_run
- metrics
from google_cloud_mldiagnostics import metrics
- xprof
from google_cloud_mldiagnostics import xprof
- MLRun
from google_cloud_mldiagnostics.proto.diagnostics import MLRun
- MetricType
from google_cloud_mldiagnostics.proto.diagnostics import MetricType
Quickstart
import os
import logging
import google.cloud.logging
from google_cloud_mldiagnostics import machinelearning_run
from google_cloud_mldiagnostics import metrics
from google_cloud_mldiagnostics import xprof
from google_cloud_mldiagnostics.proto.diagnostics import MetricType
# Set up Cloud Logging (recommended)
logging_client = google.cloud.logging.Client()
logging_client.setup_logging()
logging.info("Cloud Logging is set up.")
PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'your-gcp-project-id')
# Ensure GOOGLE_APPLICATION_CREDENTIALS is set or authenticated via gcloud CLI
def run_ml_diagnostics_example():
    print(f"Using GCP Project ID: {PROJECT_ID}")
    # 1. Create a machine learning run
    # The SDK automatically generates a unique run_id if not provided.
    run_name = machinelearning_run.create_run(
        project_id=PROJECT_ID,
        experiment_name="my-first-experiment",
        display_name="my-training-run"
    )
    print(f"Created ML Run: {run_name}")
    # 2. Record metrics
    metrics.record(MetricType.LOSS, 0.5, step=1, run_name=run_name)
    metrics.record(MetricType.ACCURACY, 0.8, step=1, run_name=run_name)
    print("Recorded initial metrics.")
    metrics.record(MetricType.LOSS, 0.2, step=10, run_name=run_name)
    metrics.record(MetricType.ACCURACY, 0.95, step=10, run_name=run_name)
    print("Recorded updated metrics.")
    # 3. Write configurations (example)
    machinelearning_run.write_config(run_name, {"learning_rate": 0.01, "batch_size": 32})
    print("Wrote run configurations.")
    # Example of capturing a profile (requires XProf server running in your workload)
    # For on-demand capture, ensure xprof.start_server() is called in your ML workload.
    # xprof.capture_profile(run_name, 'gs://your-bucket/profiles', duration_ms=10000)
    # print("Attempted to capture profile.")
    print("ML Diagnostics example completed. Check Google Cloud Console for 'my-training-run'.")


if __name__ == '__main__':
    run_ml_diagnostics_example()
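The quickstart records metrics at steps 1 and 10. In a longer training loop, a common pattern is to record only at a fixed interval to limit the number of calls. The sketch below is hypothetical. `train_and_record` and the `record_fn(name, value, step)` callback are illustrative names, and `losses` stands in for values from a real training loop. In practice, `record_fn` could wrap the `metrics.record(...)` call shown in the quickstart.

```python
from typing import Callable, Iterable

# Hypothetical callback shape: (metric_name, value, step) -> None.
RecordFn = Callable[[str, float, int], None]


def train_and_record(losses: Iterable[float], record_fn: RecordFn, log_every: int = 10) -> int:
    """Send the loss at step 1 and every `log_every` steps; return how many were sent."""
    if log_every < 1:
        raise ValueError("log_every must be >= 1")
    sent = 0
    for step, loss in enumerate(losses, start=1):
        # Always record the first step, then only at the chosen interval.
        if step == 1 or step % log_every == 0:
            record_fn("loss", loss, step)
            sent += 1
    return sent
```

Injecting `record_fn` keeps the loop testable without credentials. For example, `train_and_record([0.9, 0.8, 0.7], print, log_every=2)` prints the records for steps 1 and 2.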