Generic API for Pyro Backends
pyro-api provides a generic, backend-agnostic API for probabilistic programming, so the same model code can run on different backends (e.g., Pyro, NumPyro, Funsor) without modification. It abstracts common functionality such as distributions, inference primitives, and parameter management. The current version is 0.1.2; releases are infrequent and focus on stability and compatibility.
Warnings
- breaking The `pyro.generic` module was moved out of the `pyro` package into the standalone `pyro-api` package in version 0.1.0. Direct imports like `from pyro.generic import something` will fail.
- gotcha `pyro-api` is an interface and does not provide an implementation itself. You must install a backend library (e.g., `pyro-ppl` for Pyro, `numpyro`, `funsor`) and select it with the `pyroapi.pyro_backend()` context manager before using the generic modules.
- gotcha When using `pyro-api`, perform all operations through the generic API (e.g., `pyroapi.distributions`, `pyroapi.pyro.param`, `pyroapi.ops`). Mixing generic calls with direct backend-specific calls (e.g., calling `torch.tensor` instead of `pyroapi.ops.tensor` under the Pyro backend) can lead to unexpected behavior and breaks backend portability.
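The two gotchas above follow from how a generic API dispatches: code talks to a stand-in module, and the real module is chosen by whichever backend is active. The toy below illustrates that pattern with the standard library only. It is not pyro-api's implementation; `GenericModule`, `backend`, `root`, and the `float`/`complex` aliases are hypothetical, with `math` and `cmath` standing in for real backends such as Pyro and NumPyro.

```python
import importlib
from contextlib import contextmanager

# Hypothetical backend registry: alias -> name of the module that implements it.
_BACKENDS = {"float": "math", "complex": "cmath"}
_current = {"alias": None}


class GenericModule:
    """Forwards attribute lookups to whichever backend module is active."""

    def __getattr__(self, attr):
        alias = _current["alias"]
        if alias is None:
            # Mirrors the gotcha: an interface alone cannot do any work.
            raise RuntimeError("no backend selected")
        module = importlib.import_module(_BACKENDS[alias])
        return getattr(module, attr)


ops = GenericModule()


@contextmanager
def backend(alias):
    """Activate a backend for the duration of the with-block, then restore."""
    previous = _current["alias"]
    _current["alias"] = alias
    try:
        yield
    finally:
        _current["alias"] = previous


def root(x):
    # Uses only the generic `ops`; calling math.sqrt here would pin one backend.
    return ops.sqrt(x)


with backend("float"):
    real_result = root(4.0)       # dispatched to math.sqrt
with backend("complex"):
    complex_result = root(-4.0)   # dispatched to cmath.sqrt

no_backend_error = None
try:
    root(1.0)                     # no backend active here
except RuntimeError as err:
    no_backend_error = str(err)

print(real_result, complex_result, no_backend_error)
```

The same `root` function changes behavior with the backend, which is why a direct backend call inside otherwise generic code silently defeats the point of the abstraction.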
Install
pip install pyro-api
Imports
- pyroapi
import pyroapi
- pyro_backend
from pyroapi import pyro_backend
- distributions
from pyroapi import distributions as dist
- param
from pyroapi import pyro  # provides pyro.param, pyro.sample, pyro.plate
- handlers
from pyroapi import handlers
Quickstart
import os

# IMPORTANT: pyro-api is only an interface. You must also install a backend,
# e.g. `pip install pyro-ppl` for Pyro or `pip install numpyro` for NumPyro.
from pyroapi import distributions as dist
from pyroapi import ops, pyro, pyro_backend

# Choose the backend alias. PYRO_BACKEND is only this script's own convention
# for switching backends during testing; 'pyro' is the usual default.
backend_name = os.environ.get("PYRO_BACKEND", "pyro")

try:
    # Generic calls are dispatched to the selected backend inside this block.
    with pyro_backend(backend_name):
        # Define generic parameters; initial values must be backend tensors,
        # so create them with the generic ops.tensor rather than torch.tensor.
        mean_param = pyro.param("mean_value", ops.tensor(0.0))
        std_param = pyro.param("std_value", ops.tensor(1.0),
                               constraint=dist.constraints.positive)
        generic_normal = dist.Normal(mean_param, std_param)
        print(f"Pyro-API backend in use: {backend_name}")
        print(f"Created a generic Normal distribution with mean {mean_param} and std {std_param}")
except ImportError as e:
    # Raised if the backend's modules cannot be imported.
    print(f"Failed to load the Pyro-API backend: {e}")
    print("Please ensure your chosen backend (e.g. 'pyro-ppl' or 'numpyro') is installed.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")