Generic API for Pyro Backends

0.1.2 · active · verified Sat Apr 11

pyro-api provides a generic, backend-agnostic API for probabilistic programming. Code written against it can run on different Pyro backends (e.g., Pyro, NumPyro, Funsor) without modification. It exposes generic modules (`pyro`, `distributions`, `handlers`, `infer`, `optim`, and `ops`). Together these cover parameter management, distributions, effect handlers, inference, optimizers, and tensor operations, and each call is routed to the currently selected backend. The current version is 0.1.2. Releases are infrequent and focus on stability and compatibility.
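The dispatch idea behind a backend-agnostic API can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not pyro-api's actual implementation. The hypothetical `DispatchModule` proxy forwards attribute lookups to whichever module is currently selected. The hypothetical `use_backend` context manager swaps that selection for the duration of a `with` block. The standard-library modules `math` and `cmath` stand in for real backends such as PyTorch or JAX, so the example runs without Pyro installed.

```python
import importlib
from contextlib import contextmanager


class DispatchModule:
    """Hypothetical proxy that forwards attribute access to the active backend module."""

    # Shared routing table: logical module name -> name of the selected backend module.
    _active = {}

    def __init__(self, name, default_backend):
        self._name = name
        DispatchModule._active[name] = default_backend

    def __getattr__(self, attr):
        # Only called for attributes not found on the proxy itself, so a
        # lookup such as ops.sqrt is resolved against the selected module.
        backend = importlib.import_module(DispatchModule._active[self._name])
        return getattr(backend, attr)


@contextmanager
def use_backend(**selection):
    """Temporarily route each named proxy to a different backend module."""
    previous = {name: DispatchModule._active[name] for name in selection}
    DispatchModule._active.update(selection)
    try:
        yield
    finally:
        # Restore the previous routing even if the body raised.
        DispatchModule._active.update(previous)


# Generic "ops" module; math stands in for a real default backend.
ops = DispatchModule("ops", "math")


def hypotenuse(a, b):
    # Written once against the generic API; never names a concrete backend.
    return ops.sqrt(a * a + b * b)


print(hypotenuse(3.0, 4.0))  # prints 5.0 (math.sqrt)
with use_backend(ops="cmath"):
    print(hypotenuse(3.0, 4.0))  # prints (5+0j) (cmath.sqrt)
```

The routing table is consulted on every lookup, so the same `hypotenuse` function returns a `float` under the default and a `complex` inside the `with` block. The `finally` clause restores the default even if the body raises.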

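Warnings

pyro-api is only an interface and implements no inference of its own. A concrete backend such as Pyro (`pip install pyro-ppl`) or NumPyro (`pip install numpyro`) must be installed separately. Note that the PyPI package named `pyro` is an unrelated project (Python Remote Objects), not the Pyro probabilistic programming language.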

Install

Install the interface with `pip install pyro-api`.

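Imports

The package is imported as `pyroapi`. Its generic modules and the backend selector are top-level names, e.g. `from pyroapi import distributions as dist` and `from pyroapi import handlers, infer, ops, optim, pyro, pyro_backend`.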

Quickstart

This quickstart shows how to select a backend and use pyro-api's generic modules to declare parameters and distributions. Remember that pyro-api is only an interface: you must also install a concrete backend such as Pyro (`pyro-ppl`) or NumPyro. Backends are selected with the `pyro_backend()` context manager. Generic model code therefore runs against whichever backend is active when it is called. Standard backend names include `"pyro"` (the default), `"minipyro"`, `"numpy"` (NumPyro), and `"funsor"`.

```python
import os

# IMPORTANT: pyro-api is an interface. You must also install a backend.
# For example, run: pip install pyro-ppl   (the PyPI package "pyro" is unrelated)
# Or for NumPyro: pip install numpyro
from pyroapi import distributions as dist
from pyroapi import ops, pyro, pyro_backend


def model():
    # Generic model code: pyro, dist and ops resolve to the active backend.
    mean = pyro.param("mean_value", ops.tensor(0.0))
    std = pyro.param("std_value", ops.tensor(1.0), constraint=dist.constraints.positive)
    return dist.Normal(mean, std)


# Standard backend names include "pyro", "minipyro", "numpy" (NumPyro) and "funsor".
backend_name = os.environ.get("PYRO_BACKEND", "pyro")  # env var makes testing easy

try:
    # If the chosen backend is not installed, an ImportError is raised
    # when its modules are loaded.
    with pyro_backend(backend_name):
        normal = model()
        print(f"Backend {backend_name!r}: created Normal(loc={normal.loc}, scale={normal.scale})")
except ImportError as e:
    print(f"Failed to load the pyro-api backend: {e}")
    print("Please ensure your chosen backend (e.g. pyro-ppl or numpyro) is installed.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
```
