Generalized Additive Models in Python (pyGAM)
pyGAM is a Python library for building Generalized Additive Models (GAMs), with an emphasis on modularity and performance. GAMs extend generalized linear models by representing each feature's effect as a non-linear function built from penalized B-splines. Because the model remains additive, it is both flexible and interpretable. The current stable version is 0.12.0. The project is actively maintained with a focus on compatibility, and contributions are welcome.
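For reference, the additive structure described above can be written as follows. This is the textbook GAM form, not a transcription of pyGAM's internals. The penalized objective assumes the squared-error loss of a `LinearGAM`, a model matrix $X$ of B-spline basis evaluations $B_{jk}$, one smoothing parameter $\lambda_j \ge 0$ (pyGAM's `lam`), and one penalty matrix $P_j$ per term.

```latex
\begin{aligned}
g\big(\mathbb{E}[y \mid \mathbf{x}]\big) &= \beta_0 + \sum_{j=1}^{p} f_j(x_j),
\qquad f_j(x_j) = \sum_{k=1}^{K_j} \beta_{jk}\, B_{jk}(x_j) \\
\hat{\boldsymbol\beta} &= \arg\min_{\boldsymbol\beta}\;
\lVert \mathbf{y} - X\boldsymbol\beta \rVert_2^2
+ \sum_{j=1}^{p} \lambda_j\, \boldsymbol\beta_j^{\top} P_j\, \boldsymbol\beta_j
\end{aligned}
```

Interpretability follows from the first line: each $f_j$ depends on a single feature and can be inspected on its own. Larger $\lambda_j$ values give smoother $f_j$.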
Warnings
- gotcha Installing NumPy/SciPy linked against Intel MKL for acceleration can be tricky, especially with Conda, because of channel compatibility issues. MKL-linked NumPy builds available through pip are often outdated. If MKL optimization is critical, consider third-party builds or specific Conda channels.
- gotcha P-values computed from models whose smoothing parameters were estimated from the data can be lower than they should be. This raises the false-positive rate, meaning the null hypothesis is rejected too readily.
- gotcha Combining a spline term (e.g., `s(feature)`) and a linear term for the *same* feature in one GAM can make the model unidentifiable, because the spline can already represent a linear trend. As a result, p-values may appear statistically significant when they are not.
- gotcha The `pyGAM` package available on `conda-forge` is typically less up-to-date than the version available via `pip`.
- gotcha For large models with constraints, installing `scikit-sparse` can significantly speed up optimization thanks to its faster sparse Cholesky factorization. Importing it also requires `nose` to be installed.
- gotcha pyGAM is officially tested on Python 3.10 and later and is compatible up to Python 3.13. Older Python versions, or newer ones that have not been tested, may cause unexpected compatibility issues.
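To see whether the optional `scikit-sparse` speed-up from the warnings above is usable in an environment, a quick standard-library check can look for both packages. This is a minimal sketch: it assumes `scikit-sparse` is importable as `sksparse` and that `nose` is the extra import-time dependency mentioned above. `find_spec` only confirms that a package is present, not that importing it succeeds.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module named `name` can be found."""
    return importlib.util.find_spec(name) is not None


# Assumption: scikit-sparse installs as `sksparse`, and `nose` must also be present
sparse_ok = has_module("sksparse") and has_module("nose")
print("scikit-sparse fast path likely available:", sparse_ok)
```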
Install
pip install pygam
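After installing, you can confirm which pyGAM version is active. This is useful because the pip and conda-forge builds can differ. The sketch below uses only the standard library's `importlib.metadata` and assumes the distribution name is `pygam`. The helper `installed_version` is a hypothetical name used for illustration.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints something like "pygam: 0.12.0", or "pygam: None" if it is not installed
print("pygam:", installed_version("pygam"))
```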
Imports
- LinearGAM
from pygam import LinearGAM
- LogisticGAM
from pygam import LogisticGAM
- PoissonGAM
from pygam import PoissonGAM
- s
from pygam import s
- f
from pygam import f
- te
from pygam import te
Quickstart
import numpy as np
from pygam import LinearGAM, s, f
from pygam.datasets import wage
X, y = wage()  # Load the example wage dataset (features: year, age, education)
# Define a GAM with spline terms for features 0 (year) and 1 (age) and a factor term for feature 2 (education)
gam = LinearGAM(s(0) + s(1) + f(2))
# Fit the model
gam.fit(X, y)
# Print a summary of the model fit (summary() prints directly and returns None)
gam.summary()
# Predict for new observations; keep values inside the training ranges
# (columns: year, age, education code 0-4) to avoid extrapolating
new_X = np.array([[2006, 35, 2], [2009, 50, 4]])
predictions = gam.predict(new_X)
print(f"Predictions for new data: {predictions}")