sklearn-compat
sklearn-compat is a small Python package that helps developers write scikit-learn-compatible estimators supporting multiple scikit-learn versions. It factors out utilities that third-party libraries commonly use to handle version differences, and exposes them through a stable API. As of version 0.1.5, it supports scikit-learn >= 1.2, with recent updates for scikit-learn 1.8 and 1.9. Releases follow new scikit-learn releases, and the project aims to support scikit-learn versions from roughly the last 2 years (about 4 versions).
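The core idea behind a compatibility layer like this can be sketched in a few lines. The layer reads the installed scikit-learn version once and then selects an implementation based on it. The helpers below (`parse_version_tuple`, `installed_sklearn_version`, `sklearn_ge`, `pick_implementation`) are hypothetical illustrations and are not part of sklearn-compat's API. The version parsing is deliberately simplistic, the 1.6 threshold is arbitrary, and the branch labels are placeholders.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version_tuple(version_string):
    """Hypothetical helper: turn '1.5.2' or '1.6.dev0' into (major, minor, micro).

    Only the leading numeric part of each component is kept; suffixes such as
    '.dev0' or 'rc1' are ignored, which is enough for coarse version gates.
    """
    parts = []
    for piece in version_string.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    while len(parts) < 3:  # pad so '1.6' compares like '1.6.0'
        parts.append(0)
    return tuple(parts)


def installed_sklearn_version():
    """Return the installed scikit-learn version tuple, or None if it is absent."""
    try:
        return parse_version_tuple(version("scikit-learn"))
    except PackageNotFoundError:
        return None


def sklearn_ge(target, installed):
    """True if `installed` is at least `target`, e.g. target=(1, 6)."""
    return installed is not None and installed[: len(target)] >= tuple(target)


def pick_implementation(installed):
    # A real shim would define the same public name in each branch.
    if sklearn_ge((1, 6), installed):
        return "new-style (>= 1.6)"
    if sklearn_ge((1, 2), installed):
        return "legacy (1.2 - 1.5)"
    return "unsupported"


print(pick_implementation(installed_sklearn_version()))
```

Branching once at import time keeps the version checks out of the hot path. Downstream code then imports a single name regardless of which scikit-learn is installed.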
Warnings
- breaking Importing private utilities directly from `sklearn` to support multiple versions is fragile, because scikit-learn's internal API is not stable and can change between releases. `sklearn-compat` exists to provide a stable compatibility layer over these utilities.
- gotcha scikit-learn 1.2 changed how parameters of estimators and functions are validated. When supporting `scikit-learn >= 1.2`, decorate your estimators' `fit` methods with `_fit_context` and your functions with `validate_params`. Omitting them can lead to inconsistent behavior or failures across versions.
- gotcha scikit-learn 1.8 changed internal utilities such as `_check_targets`, which now returns 4 values. Custom code that unpacks a different number of return values will break.
- gotcha In scikit-learn 1.5, many developer utilities were moved to dedicated modules. Importing them from their old paths fails in newer versions.
- gotcha Instead of adding a package dependency, you can vendor `sklearn-compat` by copying `src/sklearn_compat/_sklearn_compat.py` into your project. Do not mix the vendored file with an installed `sklearn-compat` package, and keep the vendored copy up to date. Otherwise you risk conflicts or missed fixes.
Install
pip install sklearn-compat
Imports
- _fit_context
from sklearn_compat.base import _fit_context
- validate_params
from sklearn_compat.utils import validate_params
- is_clusterer
from sklearn_compat.base import is_clusterer
- _check_targets
from sklearn_compat.utils import _check_targets
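When a utility has only moved between modules, the usual compatibility pattern is to try the new location first and fall back to the old one. The helper `import_from_first` below is a hypothetical sketch of that pattern using only `importlib`. It is not how sklearn-compat is implemented. The demo uses standard-library modules so that it runs without scikit-learn installed.

```python
import importlib


def import_from_first(candidates):
    """Return the first attribute importable from a list of (module, name) pairs.

    Candidates are tried in order, so list the newest location first.
    Raises ImportError naming every location tried if none of them works.
    """
    tried = []
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except (ImportError, AttributeError):
            tried.append(f"{module_name}.{attr_name}")
    raise ImportError("could not import any of: " + ", ".join(tried))


# Demo: the first candidate module does not exist, so the fallback is used.
sqrt = import_from_first([("hypothetical_new_module", "sqrt"), ("math", "sqrt")])
print(sqrt(16.0))  # 4.0
```

One caveat: catching `ImportError` can also hide a genuine error raised while importing a module that does exist. Keep the candidate list short and explicit.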
Quickstart
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted
from sklearn_compat.base import _fit_context
import numpy as np


# Mixins go to the left of BaseEstimator, as scikit-learn recommends.
class MyCompatibleClassifier(ClassifierMixin, BaseEstimator):
    # _fit_context validates constructor parameters against
    # _parameter_constraints, so the attribute must exist even when the
    # estimator has no parameters.
    _parameter_constraints: dict = {}

    # The _fit_context decorator ensures parameter validation consistent
    # with scikit-learn's internal mechanisms across versions (e.g., 1.2+).
    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y):
        X = np.asarray(X)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        self.n_features_in_ = X.shape[1]
        self.is_fitted_ = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = np.asarray(X)
        # Simple prediction logic for demonstration: always the first class
        return np.full(X.shape[0], self.classes_[0])


# Example usage of the compatible classifier
X_train = np.array([[1, 2], [3, 4], [5, 6]])
y_train = np.array([0, 1, 0])
clf = MyCompatibleClassifier()
clf.fit(X_train, y_train)
print(f"Fitted classes: {clf.classes_}")
print(f"Predicted for [[7, 8]]: {clf.predict(np.array([[7, 8]]))}")