Scikit-Plot

0.3.7 · maintenance · verified Tue Apr 14

Scikit-plot (v0.3.7) is an intuitive Python library that extends scikit-learn objects with easy-to-use plotting functionality. It aims to simplify the visualization of machine learning models and metrics, such as confusion matrices, ROC curves, and learning curves, with minimal boilerplate. Releases have been infrequent: the latest stable version was published in August 2018.
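To make concrete what an ROC curve plot shows, here is a minimal NumPy sketch that computes the (false positive rate, true positive rate) points by sweeping a decision threshold over predicted scores for a binary problem. The helper `roc_points` is hypothetical and is not scikit-plot's implementation; it assumes 0/1 (or boolean) labels with both classes present.

```python
import numpy as np

def roc_points(y_true, scores):
    """Hypothetical helper: return (fpr, tpr) arrays, one point per distinct threshold."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    # Sort by descending score; each prefix is "predicted positive" at some threshold
    order = np.argsort(-scores, kind="mergesort")
    y_sorted = y_true[order]
    s_sorted = scores[order]
    # Cumulative true/false positive counts as the threshold is lowered
    tps = np.cumsum(y_sorted)
    fps = np.cumsum(~y_sorted)
    # Keep only the last position of each run of tied scores
    distinct = np.r_[np.diff(s_sorted) != 0, True]
    tps, fps = tps[distinct], fps[distinct]
    # Prepend (0, 0): a threshold above every score predicts nothing positive
    tpr = np.r_[0.0, tps / y_true.sum()]
    fpr = np.r_[0.0, fps / (~y_true).sum()]
    return fpr, tpr

# Usage: four samples, two per class
fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
# fpr holds 0, 0, 0.5, 0.5, 1 and tpr holds 0, 0.5, 0.5, 1, 1
```

Sorting once and taking cumulative sums yields the counts for every threshold in O(n log n); collapsing tied scores into one point keeps the curve independent of the order of tied samples.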

Install

pip install scikit-plot

Imports

import scikitplot as skplt
import matplotlib.pyplot as plt

Quickstart

This quickstart trains a RandomForestClassifier on the digits dataset, then uses scikit-plot to visualize its normalized confusion matrix on a held-out test set with a single call to skplt.metrics.plot_confusion_matrix. It also shows the typical import pattern (import scikitplot as skplt).

import matplotlib.pyplot as plt
import scikitplot as skplt
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load dataset
X, y = load_digits(return_X_y=True)

# Hold out a third of the data for evaluation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# Train a classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Get predictions
y_pred = clf.predict(X_test)

# Plot confusion matrix
skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True)
plt.title('Normalized Confusion Matrix')
plt.show()
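My understanding is that normalize=True divides each row of the confusion matrix by the number of samples in that true class, so each cell reads as "fraction of class i predicted as class j". The NumPy sketch below reproduces only that row normalization to make the plotted numbers easier to interpret; `normalized_confusion_matrix` is a hypothetical helper, not part of scikit-plot.

```python
import numpy as np

def normalized_confusion_matrix(y_true, y_pred, labels=None):
    """Hypothetical helper: confusion matrix whose rows (true classes) sum to 1."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if labels is None:
        labels = np.unique(np.concatenate([y_true, y_pred]))
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=float)
    # Rows are true labels, columns are predicted labels
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1
    row_sums = cm.sum(axis=1, keepdims=True)
    # Divide each row by its class count; rows for absent classes become 0, not NaN
    with np.errstate(invalid="ignore", divide="ignore"):
        cm = np.where(row_sums > 0, cm / row_sums, 0.0)
    return cm

# Usage: class 0 is half right, class 1 is two-thirds right
cm = normalized_confusion_matrix([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
```

Row normalization makes per-class recall visible on the diagonal regardless of class imbalance, which is why it is a common default for reading multi-class results such as the ten digit classes above.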
