SHAP (SHapley Additive exPlanations)

0.51.0 · active · verified Mon Apr 06

SHAP is a Python library that explains the output of any machine learning model using Shapley values, a concept from cooperative game theory. It provides several 'explainers' optimized for different model types (e.g., tree models, deep learning models, or model-agnostic approaches), along with plotting functions for inspecting the results. The current version is 0.51.0, and the project is actively developed, with minor releases typically every month or two.
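
To make the Shapley-value idea concrete, here is a minimal brute-force sketch in plain Python. It is not how SHAP computes values internally (the library uses far faster, model-specific algorithms). The names `toy_model`, `value_fn`, `x`, and `baseline` are hypothetical, and replacing "absent" features with a fixed baseline is a simplifying assumption made only for this illustration.

```python
from itertools import combinations
from math import factorial


def shapley_values(value_fn, n_players):
    """Exact Shapley values by enumerating every coalition (exponential cost)."""
    phi = [0.0] * n_players
    for i in range(n_players):
        others = [p for p in range(n_players) if p != i]
        for size in range(len(others) + 1):
            # Shapley weight for coalitions of this size: |S|! (n - |S| - 1)! / n!
            weight = (factorial(size) * factorial(n_players - size - 1)
                      / factorial(n_players))
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                # Marginal contribution of player i when joining coalition s
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi


# Hypothetical toy "model": a linear part plus an interaction between features 0 and 1.
def toy_model(v):
    return 2.0 * v[0] + 1.0 * v[1] - 1.0 * v[2] + 0.5 * v[0] * v[1]


x = [1.0, 2.0, 3.0]         # instance to explain
baseline = [0.0, 0.0, 0.0]  # assumed reference values for "absent" features


def value_fn(coalition):
    # Features in the coalition keep their real value; the rest use the baseline.
    masked = [x[j] if j in coalition else baseline[j] for j in range(len(x))]
    return toy_model(masked)


phi = shapley_values(value_fn, len(x))
base_value = value_fn(frozenset())
print(phi)
print(base_value + sum(phi), toy_model(x))
```

In this toy setup the interaction term is split evenly between features 0 and 1, and the base value plus the three contributions adds back up to the model output for `x`. That additivity is the property the "Additive" in the library's name refers to.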

Warnings

Install

Imports

Quickstart

This quickstart uses SHAP to explain an XGBoost classifier. It generates synthetic data, trains a model, initializes the SHAP explainer (which automatically detects the model type), computes SHAP values with the modern callable explainer API, and visualizes feature impacts with a beeswarm plot. In Jupyter, `shap.initjs()` is needed only for the JavaScript-based plots such as force plots; the matplotlib-based beeswarm plot used here renders without it.

import shap
import xgboost
import pandas as pd
from sklearn.datasets import make_classification

# Generate synthetic data
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=2, n_classes=2, random_state=42)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])

# Train an XGBoost model
model = xgboost.XGBClassifier(eval_metric='logloss', random_state=42)
model.fit(X, y)

# For Jupyter notebooks, uncomment the following line to enable JavaScript visualizations:
# shap.initjs()

# Create a SHAP Explainer (automatically infers TreeExplainer for XGBoost).
# X is passed as the background (masker) data.
explainer = shap.Explainer(model, X)

# Calculate SHAP values for the dataset
shap_values = explainer(X)

# Visualize the global impact of features using a beeswarm plot
# This plot shows the distribution of SHAP values for each feature.
shap.plots.beeswarm(shap_values, max_display=10)

# To visualize an individual prediction (e.g., the first instance) with a waterfall plot:
# shap.plots.waterfall(shap_values[0])
