SHAP (SHapley Additive exPlanations)
SHAP is a Python library that provides a unified approach to explaining the output of any machine learning model using Shapley values, a concept from game theory. It offers various 'explainers' optimized for different model types (e.g., tree models, deep learning models, or model-agnostic approaches) as well as powerful visualization tools. At the time of writing it is at version 0.51.0, with an active development cycle and minor releases typically landing every month or two.
Warnings
- breaking SHAP v0.50.0 dropped official support for Python 3.9 and 3.10; the library now requires Python 3.11 or newer.
- breaking The SHAP API for calculating explanation values and plotting underwent a significant change around v0.36.0. The `explainer.shap_values(X)` method was replaced by making the `explainer` object directly callable (`explainer(X)`), returning a `shap.Explanation` object. Legacy plotting functions like `shap.summary_plot` were deprecated, and new plotting functions (e.g., `shap.plots.beeswarm`, `shap.plots.waterfall`) expect the `shap.Explanation` object. The `auto_size_plot` parameter was removed from `shap.summary_plot` in v0.46.0.
- gotcha While SHAP v0.46.0 added support for NumPy 2.0, some machine learning libraries that SHAP depends on (e.g., TensorFlow, SciPy, Scikit-learn, Pandas, Numba) might still have compatibility issues or explicit version pins that prevent them from working with NumPy 2.0. This can lead to dependency conflicts during installation or runtime errors if not managed carefully.
- gotcha For large datasets or high-dimensional inputs, `shap.KernelExplainer` can be very slow due to its model-agnostic nature, requiring numerous model evaluations. For tree-based models (like XGBoost, LightGBM, CatBoost), `shap.TreeExplainer` is significantly faster and provides exact SHAP values.
Install
-
pip install shap
Imports
- shap
import shap
- shap.Explainer
explainer = shap.Explainer(model, data)
- shap_values
shap_values = explainer(data)
- shap.plots.beeswarm
shap.plots.beeswarm(shap_values)
Quickstart
import shap
import xgboost
import pandas as pd
from sklearn.datasets import make_classification
# Generate synthetic data
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=2, n_classes=2, random_state=42)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
# Train an XGBoost model
model = xgboost.XGBClassifier(eval_metric='logloss', random_state=42)  # use_label_encoder is deprecated/removed in recent XGBoost
model.fit(X, y)
# For Jupyter notebooks, uncomment the following line to enable JavaScript visualizations:
# shap.initjs()
# Create a SHAP Explainer (automatically infers TreeExplainer for XGBoost)
explainer = shap.Explainer(model, X)
# Calculate SHAP values for the dataset
shap_values = explainer(X)
# Visualize the global impact of features using a beeswarm plot
# This plot shows the distribution of SHAP values for each feature.
shap.plots.beeswarm(shap_values, max_display=10)
# To visualize an individual prediction (e.g., the first instance) with a waterfall plot:
# shap.plots.waterfall(shap_values[0])