sklearn-evaluation

Version 0.12.2 | verified Fri May 01 | auth: no | language: python

A Python library for evaluating scikit-learn models, providing a rich set of plots, tables, and markdown reports. Current version 0.12.2. Released irregularly; latest releases are minor patches.

pip install sklearn-evaluation
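pip and the package metadata identify the package by its distribution name, sklearn-evaluation. The following is a minimal sketch for checking the installed version from Python. It uses only the standard library's importlib.metadata (Python 3.8+). The helper name installed_version is illustrative and is not part of sklearn-evaluation.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed.

    The lookup is by distribution (PyPI) name, e.g. "sklearn-evaluation".
    """
    try:
        return version(distribution_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints a version string, or None if sklearn-evaluation is not installed
    print(installed_version("sklearn-evaluation"))
```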
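error ModuleNotFoundError: No module named 'sklearn-evaluation'
cause Using the hyphenated PyPI name as the module name in a string-based import, e.g. importlib.import_module('sklearn-evaluation'). (A literal 'import sklearn-evaluation' statement fails even earlier, with a SyntaxError.)
fix Use 'import sklearn_evaluation' or 'from sklearn_evaluation import ...'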
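String-based imports, such as importlib.import_module, plugin loaders, and config files, are where the hyphenated name most often slips through. The sketch below wraps importlib.import_module and re-raises with an underscore suggestion. import_module_with_hint is a hypothetical helper. Replacing '-' with '_' is only a heuristic, because some distributions import under an unrelated name (scikit-learn, for example, imports as sklearn).

```python
import importlib
from types import ModuleType


def import_module_with_hint(name: str) -> ModuleType:
    """Import a module by name; on failure, suggest an underscore spelling for hyphenated names."""
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError as exc:
        # Only rewrite the error when the missing module is the one we asked for
        if "-" not in name or exc.name != name:
            raise
        suggestion = name.replace("-", "_")  # heuristic only, not a universal rule
        raise ModuleNotFoundError(
            f"{exc}; import names normally use underscores, did you mean {suggestion!r}?",
            name=name,
        ) from exc


if __name__ == "__main__":
    try:
        import_module_with_hint("sklearn-evaluation")
    except ModuleNotFoundError as err:
        print(err)
```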
error AttributeError: module 'sklearn_evaluation' has no attribute 'plot'
cause Usually, referencing 'sklearn_evaluation.plot' after only 'import sklearn_evaluation'. 'plot' is a subpackage, and Python does not guarantee that a subpackage is available as an attribute of its parent until the subpackage itself has been imported.
fix Import the subpackage explicitly and use the class-based API: from sklearn_evaluation import plot; plot.ConfusionMatrix.from_raw_data(y_true, y_pred)
gotcha Import from 'sklearn_evaluation' with underscores, not 'sklearn-evaluation' (hyphen). The package name on PyPI uses a hyphen, but the import uses an underscore.
fix Use 'from sklearn_evaluation import ...'
deprecated Version 0.5 introduced a new API for reports, deprecating the old functional helpers in 'sklearn_evaluation.plot' (e.g. plot.confusion_matrix). Use the object-oriented classes in the same module (e.g. ClassificationReport, ConfusionMatrix) instead.
fix Use classes such as ClassificationReport and ConfusionMatrix: from sklearn_evaluation.plot import ClassificationReport, ConfusionMatrix
breaking In version 0.5, the report generation API changed; old code using 'sklearn_evaluation.report' may break.
fix Migrate to the class-based API: from sklearn_evaluation.plot import ClassificationReport
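In Python generally, 'import pkg' does not by itself make 'pkg.sub' available. A subpackage becomes an attribute of its parent only after it has been imported, either by the parent's __init__.py or by your own code. The sketch below demonstrates this with a throwaway package. demo_parent_pkg and its plot subpackage are hypothetical stand-ins created in a temporary directory, not sklearn-evaluation itself.

```python
import importlib
import sys
import tempfile
from pathlib import Path
from typing import Tuple


def subpackage_binding_demo() -> Tuple[bool, bool]:
    """Return (bound after importing only the parent, bound after importing the subpackage)."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg_dir = Path(tmp) / "demo_parent_pkg"  # hypothetical package name
        (pkg_dir / "plot").mkdir(parents=True)
        (pkg_dir / "__init__.py").write_text("")  # parent does not import .plot
        (pkg_dir / "plot" / "__init__.py").write_text("")
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            parent = importlib.import_module("demo_parent_pkg")
            before = hasattr(parent, "plot")  # False: subpackage not imported yet
            importlib.import_module("demo_parent_pkg.plot")
            after = hasattr(parent, "plot")  # True: importing it binds parent.plot
        finally:
            sys.path.remove(tmp)
            # Drop the throwaway modules so the demo can be re-run cleanly
            for mod_name in [m for m in sys.modules if m.split(".")[0] == "demo_parent_pkg"]:
                del sys.modules[mod_name]
    return before, after


if __name__ == "__main__":
    print(subpackage_binding_demo())
```

Importing the subpackage explicitly, with 'import pkg.sub' or 'from pkg import sub', is the reliable way to reference it. This works whatever the parent's __init__.py happens to import.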

Train a Random Forest classifier and generate a classification report.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
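from sklearn.metrics import classification_report
from sklearn_evaluation import plot
import matplotlib.pyplot as plt

# Synthetic binary classification data, split into train and test sets
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = RandomForestClassifier(random_state=0)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Text report (scikit-learn): per-class precision, recall, F1 and support
print(classification_report(y_test, y_pred))

# Plotted report (sklearn-evaluation), built from the true and predicted labels
report = plot.ClassificationReport.from_raw_data(y_test, y_pred)
plt.show()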
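A classification report's columns are per-class precision, recall, F1 score, and support. The plain-Python sketch below shows what those numbers mean. It follows scikit-learn's definitions and returns 0.0 wherever a ratio would divide by zero (the zero_division=0 convention). per_class_report is a hypothetical helper for illustration only; it does not reflect how sklearn-evaluation computes its values.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


def per_class_report(
    y_true: Sequence[Hashable], y_pred: Sequence[Hashable]
) -> Dict[Hashable, Dict[str, float]]:
    """Per-class precision, recall, F1 and support, using 0.0 wherever a ratio would divide by zero."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    true_counts = Counter(y_true)  # support: how often each label truly occurs
    pred_counts = Counter(y_pred)  # how often each label was predicted
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # true positives per label
    report: Dict[Hashable, Dict[str, float]] = {}
    for label in sorted(set(true_counts) | set(pred_counts)):
        tp = hits[label]
        precision = tp / pred_counts[label] if pred_counts[label] else 0.0
        recall = tp / true_counts[label] if true_counts[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {"precision": precision, "recall": recall, "f1-score": f1, "support": true_counts[label]}
    return report


if __name__ == "__main__":
    for label, row in per_class_report([0, 1, 1, 0, 1], [0, 1, 0, 0, 1]).items():
        print(label, {k: round(v, 3) for k, v in row.items()})
```

scikit-learn's classification_report also adds an accuracy row plus macro and weighted averages, which this sketch omits.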