PyStan

3.10.1 · active · verified Mon Apr 13

PyStan is a Python interface to Stan, a platform for Bayesian inference and high-performance statistical computation. Users write statistical models in Stan's probabilistic programming language and fit them with Hamiltonian Monte Carlo (HMC). Currently at version 3.10.1, PyStan concentrates on providing a reliable HMC sampler, and new minor versions are released frequently.
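
Stan's default sampler is the No-U-Turn Sampler (NUTS), an HMC variant that chooses trajectory lengths automatically, with the step size and a diagonal mass matrix tuned during warmup. The NumPy sketch below is not Stan's implementation and is not part of PyStan. It is a plain, fixed-step HMC sampler for a hypothetical standard-normal target, meant only to show the core mechanics: draw a fresh momentum, integrate Hamiltonian dynamics with the leapfrog method, then accept or reject the proposal based on the change in total energy. The step size and number of leapfrog steps are arbitrary assumptions.

```python
import numpy as np


def log_prob(q):
    # Hypothetical target: standard normal log density, up to a constant
    return -0.5 * np.dot(q, q)


def grad_log_prob(q):
    # Gradient of log_prob with respect to q
    return -q


def leapfrog(q, p, step_size, n_steps):
    # Leapfrog integrator: half momentum step, alternating full steps, half momentum step
    p = p + 0.5 * step_size * grad_log_prob(q)
    for i in range(n_steps):
        q = q + step_size * p
        if i < n_steps - 1:
            p = p + step_size * grad_log_prob(q)
    p = p + 0.5 * step_size * grad_log_prob(q)
    return q, p


def hmc(n_draws, dim, step_size=0.2, n_steps=10, seed=0):
    rng = np.random.default_rng(seed)
    q = np.zeros(dim)
    draws = np.empty((n_draws, dim))
    n_accepted = 0
    for t in range(n_draws):
        p0 = rng.standard_normal(dim)  # fresh momentum (identity mass matrix)
        q_new, p_new = leapfrog(q, p0, step_size, n_steps)
        # Total energy H = potential (-log_prob) + kinetic (0.5 * p.p)
        h_old = -log_prob(q) + 0.5 * np.dot(p0, p0)
        h_new = -log_prob(q_new) + 0.5 * np.dot(p_new, p_new)
        # Metropolis correction for the integrator's energy error
        if np.log(rng.uniform()) < h_old - h_new:
            q = q_new
            n_accepted += 1
        draws[t] = q
    return draws, n_accepted / n_draws


draws, accept_rate = hmc(n_draws=5000, dim=2, seed=1)
print(draws.mean(axis=0), draws.var(axis=0), accept_rate)
```

With a small step size the energy error stays small and almost every proposal is accepted; a larger `step_size` lowers the acceptance rate. Stan's warmup tunes its step size toward a target acceptance statistic (0.8 by default) instead of fixing it by hand as done here.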

Quickstart

This quickstart shows how to define a Stan model, supply data, build the model, draw samples from the posterior distribution with HMC, and extract the results for analysis. It uses the classic Eight Schools hierarchical model in its non-centered form: the sampler works with standardized deviations `eta`, and each school's effect is computed as `theta = mu + tau * eta`. The `stan.build()` call compiles the model, which can take a while; `posterior.sample()` then runs the MCMC chains.

import stan
import numpy as np

schools_code = """
data {
  int<lower=0> J; // number of schools
  array[J] real y; // estimated treatment effects
  array[J] real<lower=0> sigma; // standard error of effect estimates
}
parameters {
  real mu; // population treatment effect
  real<lower=0> tau; // standard deviation in treatment effects
  vector[J] eta; // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta; // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1); // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""

schools_data = {
    "J": 8,
    "y": [28, 8, -3, 7, -1, 1, 18, 12],
    "sigma": [15, 10, 16, 11, 9, 11, 10, 18],
}

# Build the model: the Stan program is translated to C++ and compiled (this step can be slow)
posterior = stan.build(schools_code, data=schools_data, random_seed=1)

# Sample from the posterior: 4 chains, 1000 post-warmup draws per chain
fit = posterior.sample(num_chains=4, num_samples=1000)

# Extract the draws for one parameter as a NumPy array (all chains combined)
mu_samples = fit["mu"]
print(f"Mean of mu samples: {np.mean(mu_samples)}")

# To get all draws as a pandas DataFrame (requires pandas to be installed):
# import pandas as pd
# df = fit.to_frame()
# print(df.head())
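
To make the `target +=` statements concrete, the following NumPy sketch evaluates the log density that the Eight Schools model block accumulates, for parameter values you pass in. The helpers `normal_lpdf`, `eight_schools_target`, and `unconstrained_log_density` are hypothetical and not part of PyStan. The sketch assumes two Stan conventions: an `_lpdf` function used with `target +=` keeps its normalizing constants (the `~` form drops them), and a `<lower=0>` parameter such as `tau` is sampled on the log scale, which adds the log Jacobian `log(tau)` to the density HMC actually explores.

```python
import math

import numpy as np


def normal_lpdf(x, loc, scale):
    # Summed normal log density, INCLUDING the normalizing constant,
    # like a vectorized Stan *_lpdf call used with `target +=`
    x, loc, scale = (np.asarray(a, dtype=float) for a in (x, loc, scale))
    z = (x - loc) / scale  # broadcasts to the common shape
    return float(np.sum(-0.5 * z**2 - np.log(scale) - 0.5 * math.log(2 * math.pi)))


def eight_schools_target(mu, tau, eta, y, sigma):
    # Hypothetical helper: the model-block `target` for given parameter values
    eta = np.asarray(eta, dtype=float)
    theta = mu + tau * eta  # transformed parameters block
    target = 0.0
    target += normal_lpdf(eta, 0.0, 1.0)  # target += normal_lpdf(eta | 0, 1)
    target += normal_lpdf(y, theta, sigma)  # target += normal_lpdf(y | theta, sigma)
    # mu and tau have no prior statements, so they add nothing (implicit flat priors)
    return target


def unconstrained_log_density(mu, log_tau, eta, y, sigma):
    # Assumption: tau = exp(log_tau) for the <lower=0> constraint,
    # so the log Jacobian |d tau / d log_tau| contributes log(tau) = log_tau
    return eight_schools_target(mu, math.exp(log_tau), eta, y, sigma) + log_tau


y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
print(eight_schools_target(mu=5.0, tau=2.0, eta=np.zeros(8), y=y, sigma=sigma))
```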
