Scikit-Plot
Scikit-plot (v0.3.7) is a Python library that adds plotting functions for scikit-learn models and metrics. It produces common visualizations such as confusion matrices, ROC curves, and learning curves with minimal boilerplate code. Releases have been infrequent: the latest stable version, v0.3.7, was published in August 2018.
Warnings
- breaking The Factory API, a previous way to instantiate plots, was deprecated in v0.3.0 and announced for removal in v0.4.0. Direct function calls from specific modules should be used instead.
- deprecated Functions `plot_precision_recall_curve` and `plot_roc_curve` were deprecated in v0.3.5. They have been replaced by `plot_precision_recall` and `plot_roc`, respectively, which offer more control over plotted curves.
- deprecated The `scikitplot.plotters` module was deprecated in v0.3.0, and its functions were distributed to more specialized modules (e.g., `skplt.metrics`, `skplt.estimators`).
- gotcha The `spectral` colormap used internally was deprecated and changed to `nipy_spectral` in v0.3.2 to avoid Matplotlib warnings/errors.
- gotcha The latest release (v0.3.7, August 2018) targets the Python and `scikit-learn` versions that were current at that time (e.g., Python 3.5-3.7). Much newer versions (e.g., Python 3.9+ or scikit-learn 1.0+) may cause compatibility issues or require pinning specific dependency versions.
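Because the compatibility gotcha above depends on which versions are actually installed, it can help to inspect the environment before debugging an import or plotting error. The following is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+). The helper names `installed_version` and `report_versions` are hypothetical, and the package list is an assumption based on the libraries named in the warnings.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report_versions(packages=("scikit-plot", "scikit-learn", "matplotlib")):
    """Map each distribution name to its installed version (None if missing)."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    # Print one line per package so mismatches are easy to spot.
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If the reported versions are much newer than the 2018-era releases, pinning older dependency versions in a dedicated virtual environment is one option to try.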
Install
pip install scikit-plot
Imports
- skplt.metrics.plot_confusion_matrix
import scikitplot as skplt
skplt.metrics.plot_confusion_matrix(y_true, y_pred)
- skplt.metrics.plot_roc
import scikitplot as skplt
skplt.metrics.plot_roc(y_true, y_probas)
- skplt.metrics.plot_precision_recall
import scikitplot as skplt
skplt.metrics.plot_precision_recall(y_true, y_probas)
Quickstart
import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Load dataset
X, y = load_digits(return_X_y=True)
# Hold out a third of the data for evaluation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
# Train a classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
# Predict labels for the held-out test set
y_pred = clf.predict(X_test)
# Plot confusion matrix
skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True)
plt.title('Normalized Confusion Matrix')
plt.show()