sklearn-compat

0.1.5 · active · verified Sat Apr 11

sklearn-compat is a small Python package that helps developers write scikit-learn compatible estimators that work across multiple scikit-learn versions. It factors out the utilities that third-party libraries commonly write to handle version differences and exposes them behind a stable API. As of version 0.1.5, it supports scikit-learn >= 1.2, with recent updates for scikit-learn 1.8 and 1.9. Its release cadence is tied to new scikit-learn releases, and it aims to support each scikit-learn version for up to 2 years, or about 4 versions.
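
To make "managing version differences" concrete, here is a minimal sketch of the kind of hand-written version gate that third-party libraries tend to repeat when they support several releases at once. It is illustrative only and does not describe how sklearn-compat is implemented: `release_tuple`, `is_supported`, and `pick_implementation` are hypothetical helper names, and the only figure taken from the description above is the 1.2 minimum.

```python
import re

# Minimum release taken from the description above (scikit-learn >= 1.2).
MINIMUM_SUPPORTED = (1, 2)


def release_tuple(version_string):
    """Return (major, minor) from strings such as '1.6.1' or '1.8.dev0'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def is_supported(version_string, minimum=MINIMUM_SUPPORTED):
    """Check whether a release falls inside the supported window."""
    return release_tuple(version_string) >= minimum


def pick_implementation(version_string, introduced_in, upstream, backport):
    """Use the upstream object once the release provides it, else a backport."""
    if release_tuple(version_string) >= introduced_in:
        return upstream
    return backport


print(is_supported("1.1.3"))     # False
print(is_supported("1.8.dev0"))  # True
print(pick_implementation("1.2.2", (1, 3), "upstream", "backport"))  # backport
```

Comparing `(major, minor)` tuples rather than raw strings avoids the classic mistake of ordering "1.10" before "1.9". A compatibility package moves checks like these out of each downstream library so that a single import path works on every supported release.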

Warnings

Install

Imports

Quickstart

This quickstart creates a scikit-learn compatible estimator using the `_fit_context` decorator from `sklearn_compat.base`. The decorator runs the estimator's parameter validation when `fit` is called, and importing it from `sklearn-compat` keeps that behavior consistent across the supported scikit-learn versions (1.2 and later).

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn_compat.base import _fit_context
from sklearn.utils.validation import check_is_fitted
import numpy as np

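class MyCompatibleClassifier(BaseEstimator, ClassifierMixin):
    # _fit_context validates the constructor parameters against this dict
    # before fit runs. It is empty here because the estimator has none.
    _parameter_constraints: dict = {}

    # Importing _fit_context from sklearn_compat keeps parameter validation
    # consistent across the supported scikit-learn versions (1.2+).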
    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y):
        # np.asarray returns ndarray inputs unchanged, so no isinstance check is needed
        X = np.asarray(X)
        y = np.asarray(y)

        self.classes_ = np.unique(y)
        self.n_features_in_ = X.shape[1]
        self.is_fitted_ = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = np.asarray(X)  # accept lists as well as arrays, matching fit
        # Demonstration only: always predict the first class seen during fit
        return np.full(X.shape[0], self.classes_[0])

# Example usage of the compatible classifier
X_train = np.array([[1, 2], [3, 4], [5, 6]])
y_train = np.array([0, 1, 0])

clf = MyCompatibleClassifier()
clf.fit(X_train, y_train)
print(f"Fitted classes: {clf.classes_}")
print(f"Predicted for [[7, 8]]: {clf.predict(np.array([[7, 8]]))}")
