Nutpie for Stan or PyMC Models
Nutpie is a Python library for sampling Stan and PyMC models with the No-U-Turn Sampler (NUTS). Its sampler core is written in Rust, and PyMC models can be compiled with either the Numba or the JAX backend. It offers an alternative MCMC sampler for probabilistic programming models, with an emphasis on speed and robustness. The current version is 0.16.8. Releases are frequent and usually bring minor bug fixes, dependency updates, and feature enhancements.
Common errors
- ModuleNotFoundError: No module named 'nutpie'
  Cause: The `nutpie` library is not installed in your Python environment.
  Fix: Install it with pip: `pip install nutpie`
- RuntimeError: PyMC >= 5 is required but was not found.
  Cause: Nutpie relies on PyMC 5 or later to define models. Either PyMC is not installed, or an older, incompatible version is present.
  Fix: Install a compatible PyMC: `pip install 'pymc>=5.0'`
- ValueError: Dimensions mismatch for variable 'your_variable_name'
  Cause: The shapes or dimensions in your PyMC model definition are inconsistent (for example, observed data whose length does not match the coordinates of its named dimension), which Nutpie's model compilation cannot resolve.
  Fix: Review the `shape` and `dims` arguments of your model's variables and make sure they agree with your observed data and with the other priors and likelihood terms.
- TypeError: 'numpy.ndarray' object has no attribute '__jax_array__'
  Cause: This can occur when unsupported Python objects, or NumPy arrays with incompatible dtypes (e.g., object dtype), are passed into JAX-compiled code, as happens when Nutpie compiles a PyMC model with its JAX backend.
  Fix: Make sure all inputs to your PyMC model (especially observed data) are plain numeric NumPy arrays (e.g., float32, int64), and avoid object arrays or other types that JAX cannot convert directly.
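The last two errors usually trace back to the observed data, so a pre-flight check before building the model can surface them with clearer messages. The sketch below is a hypothetical helper, `check_observed`, that is not part of Nutpie or PyMC. It assumes observed data arrives as array-like input and that named dimensions are described by a `coords` mapping from dimension name to labels, as in PyMC's `coords`/`dims` convention.

```python
import numpy as np


def check_observed(data, dims=None, coords=None):
    """Hypothetical pre-flight check for observed data (not a Nutpie/PyMC API).

    Returns `data` as a numeric NumPy array. Raises TypeError for data that
    cannot be made numeric, and ValueError when axis lengths disagree with
    the number of coordinate labels for the named dims.
    """
    arr = np.asarray(data)
    if arr.dtype == object:
        # Object arrays often come from pandas columns or mixed Python lists;
        # try an explicit float conversion instead of passing them through.
        try:
            arr = arr.astype(np.float64)
        except (TypeError, ValueError) as exc:
            raise TypeError("observed data must be numeric, not object dtype") from exc
    elif not (np.issubdtype(arr.dtype, np.number) or arr.dtype == np.bool_):
        # e.g. string arrays such as dtype '<U1'
        raise TypeError(f"observed data has non-numeric dtype {arr.dtype}")

    if dims is not None:
        if coords is None:
            raise ValueError("coords are required when dims are given")
        if len(dims) != arr.ndim:
            raise ValueError(f"{len(dims)} dims given for an array with {arr.ndim} axes")
        # Each axis length must equal the number of labels for its dimension.
        for axis, dim in enumerate(dims):
            n_labels = len(coords[dim])
            if arr.shape[axis] != n_labels:
                raise ValueError(
                    f"dim {dim!r}: data has length {arr.shape[axis]}, "
                    f"but coords list {n_labels} labels"
                )
    return arr
```

The returned array can then be passed as `observed=` together with the same `dims`, so shape problems are reported before Nutpie compiles the model.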
Warnings
- gotcha Older versions of Nutpie (before 0.16.7) may have compatibility issues when sampling PyMC models that make heavy use of `pymc.dims`. Upgrade to 0.16.7 or later for full `pymc.dims` support.
- gotcha Nutpie versions before 0.16.5 had a bug affecting string coordinates under pandas 3.0. If you use pandas 3.0 or later with models that have string coordinates, you may encounter errors; upgrade to 0.16.5 or later.
- gotcha In versions before 0.16.3, the `mindepth` parameter misbehaved when `check_turning=True`, which could lead to incorrect or inefficient sampling under specific configurations.
- breaking Starting from v0.16.0, step size jitter is enabled by default during NUTS sampling. This changes the default behavior, potentially leading to slightly different sampling paths compared to previous versions where it was disabled by default.
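The version thresholds in these warnings can be checked programmatically. The following is a minimal sketch, not part of Nutpie: `parse_release` and `nutpie_version_notes` are hypothetical helper names, and the version parsing is deliberately naive (plain `X.Y.Z` strings only). It reads the installed version with the standard library's `importlib.metadata.version`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import List, Tuple


def parse_release(v: str) -> Tuple[int, ...]:
    # Simplistic: handles plain "X.Y.Z" release strings only; pre-release or
    # local suffixes such as "0.16.8rc1" are not supported.
    return tuple(int(part) for part in v.split(".")[:3])


def nutpie_version_notes(v: str) -> List[str]:
    """Return the warnings listed above that apply to a nutpie version string."""
    ver = parse_release(v)
    notes = []
    if ver < (0, 16, 3):
        notes.append("mindepth misbehaves with check_turning=True (fixed in 0.16.3)")
    if ver < (0, 16, 5):
        notes.append("string coordinates may fail under pandas >= 3.0 (fixed in 0.16.5)")
    if ver < (0, 16, 7):
        notes.append("models using pymc.dims may hit compatibility issues (before 0.16.7)")
    if ver >= (0, 16, 0):
        notes.append("step size jitter is enabled by default (since 0.16.0)")
    return notes


if __name__ == "__main__":
    # Check the nutpie version installed in the current environment, if any.
    try:
        for note in nutpie_version_notes(version("nutpie")):
            print("-", note)
    except PackageNotFoundError:
        print("nutpie is not installed")
```

Tuple comparison keeps the checks readable; for real projects, `packaging.version.Version` handles pre-release and local version strings more robustly.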
Install
pip install nutpie
Imports
- sample (expects a compiled model, such as the output of compile_pymc_model)
import nutpie; nutpie.sample(...)
- compile_pymc_model
import nutpie; nutpie.compile_pymc_model(...)
Quickstart
import pymc as pm
import nutpie

# Define a simple PyMC model
with pm.Model() as model:
    # Priors
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfNormal('sigma', sigma=1)
    # Likelihood
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=[1.0, 2.0, 3.0])

# Compile first: nutpie.sample expects a compiled model, not a pm.Model
compiled = nutpie.compile_pymc_model(model)
print("Starting Nutpie sampling...")
idata = nutpie.sample(compiled)  # returns an ArviZ InferenceData object
print("Sampling complete. InferenceData:\n", idata)