Pathways-on-Cloud Utilities
pathwaysutils is a Python package of utilities for deploying and running JAX workloads on the Pathways on Cloud architecture. It handles the cloud-specific adaptations needed to run large-scale machine learning computations across multiple TPU slices. Pathways is used internally at Google for models such as Gemini, and this package brings similar benefits to Google Cloud customers. The current version is 0.1.7. The package is actively developed, and its releases are tightly coupled to JAX versions.
Warnings
- breaking: Pathways on Cloud is currently in Preview. Access requires contacting your Google Cloud account representative. Without access, the utilities will not function as intended.
- gotcha: Pathways releases are tightly coupled with JAX versions. Mismatched versions can lead to compatibility and stability issues.
- gotcha: Calling `pathwaysutils.initialize()` disables the standard JAX compilation cache. This is intentional for Pathways workloads but might affect expectations if you rely on the JAX cache elsewhere.
- gotcha: A Pathways cluster can only maintain a session with one client at a time. Attempts by multiple clients to connect simultaneously will result in connection errors for subsequent clients.
- gotcha: Errors after importing `pathwaysutils` might be caused by outdated Flask or Werkzeug versions. Upgrading these packages can sometimes resolve the issue, but may introduce conflicts with other dependencies.
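Several of the warnings above come down to which package versions are installed, so it can help to list them before debugging further. The following is a minimal sketch that uses only the standard library. The helper name `installed_versions` and the package list are illustrative assumptions and are not part of `pathwaysutils`. It only reports versions and does not know which JAX and `pathwaysutils` releases are compatible with each other.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Hypothetical list: the packages named in the warnings above.
    names = ["pathwaysutils", "jax", "jaxlib", "flask", "werkzeug"]
    for name, version in installed_versions(names).items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing this output against the release notes for your `pathwaysutils` version is one way to spot a mismatch early.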
Install
pip install pathwaysutils
Imports
- pathwaysutils
import pathwaysutils
Quickstart
import pathwaysutils
import jax
# Pathways on Cloud is currently in Preview and requires access.
# Ensure your environment variables are correctly configured for Pathways on Cloud.
# Example: os.environ['JAX_PLATFORMS'] = 'proxy'
# Initialize pathwaysutils to configure the JAX backend for Pathways on Cloud.
# This also registers a custom ArrayHandler for checkpointing and disables JAX's compilation cache.
pathwaysutils.initialize()
# Verify JAX devices are accessible through the Pathways backend
print(f"JAX devices available via Pathways: {jax.devices()}")
# Further JAX workload execution would follow here.
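The quickstart comments mention setting `JAX_PLATFORMS` in code. A minimal sketch of one way to do that without overriding a value already supplied by the deployment is shown below. The helper name `ensure_jax_platform` is hypothetical, and `'proxy'` is simply the value from the comment above. Your environment may need additional variables that are not covered here.

```python
import os


def ensure_jax_platform(env=None, platform="proxy"):
    """Set JAX_PLATFORMS only if it is unset, and return the value in effect."""
    if env is None:
        env = os.environ
    # setdefault leaves an existing value untouched.
    return env.setdefault("JAX_PLATFORMS", platform)


if __name__ == "__main__":
    print(f"JAX_PLATFORMS={ensure_jax_platform()}")
```

If you set the variable from Python, it is safest to do so before importing `jax` or `pathwaysutils`, since JAX may read its configuration from the environment at import time.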