Scikit-Learn API wrapper for Keras
Scikeras provides a Scikit-Learn compatible API wrapper for Keras models, so Keras deep learning models can be used with Scikit-Learn tools such as GridSearchCV, Pipelines, and cross-validation. The current version is 0.13.0. New releases have typically appeared every few months, often alongside new Keras or TensorFlow releases.
Common errors
- ModuleNotFoundError: No module named 'keras.wrappers.scikit_learn'
  Cause: You are importing from the old, deprecated Keras wrapper module, which is not part of the `scikeras` library.
  Fix: Update your import statement to `from scikeras.wrappers import KerasClassifier` (or `KerasRegressor`).
- ValueError: Unknown model argument: 'build_fn'
  Cause: `build_fn` is the argument name from the old `tf.keras.wrappers.scikit_learn` API. Scikeras calls it `model`.
  Fix: Replace `build_fn=create_model` with `model=create_model` when instantiating `KerasClassifier` or `KerasRegressor`.
- ImportError: cannot import name 'KerasClassifier' from 'scikeras.wrappers' (...)
  Cause: This usually means `scikeras` >= 0.13.0 is installed alongside an older Keras (e.g., Keras 2.x). Scikeras v0.13.0 and later require Keras 3.
  Fix: Install Keras 3 with `pip install --upgrade keras`. If you must stay on Keras 2, downgrade Scikeras to a compatible version instead (e.g., `pip install 'scikeras<0.13.0'`).
- TypeError: __init__() got an unexpected keyword argument 'learn_rate'
  Cause: An optimizer setting reached a constructor under a name it does not accept. Keras optimizers such as Adam expect `learning_rate`, and optimizer settings must either be keyword arguments of your `model` function or be passed to the `KerasClassifier`/`KerasRegressor` constructor with the `optimizer__` prefix (e.g., `optimizer__learning_rate`).
  Fix: Pass the value as `optimizer__learning_rate=0.01` to `KerasClassifier`, or accept it as a parameter of your model-building function and create the optimizer there.
Warnings
- breaking Scikeras v0.13.0 drops support for Keras 2.x, TensorFlow < 2.15.0, and older Scikit-Learn versions. It requires Keras >= 3.0.0 and Python >= 3.9.
- breaking Scikeras v0.11.0 dropped support for Python 3.7. The current release (0.13.0) requires Python 3.9 or newer.
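- gotcha Scikeras expects the `model` argument to be a callable (function) that builds and returns a Keras model, not a `tf.keras.Model` object you built ahead of time. The function may compile the model itself or return it uncompiled and let Scikeras compile it from the `loss`/`optimizer` arguments passed to the wrapper.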
- gotcha TensorFlow Datasets (`tf.data.Dataset`) are not directly supported as inputs (X, y) for `fit()`, `predict()`, or `score()` methods. Inputs must be NumPy arrays or similar array-like structures.
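When data already lives in a `tf.data.Dataset`, one workaround is to materialize it into NumPy arrays before calling `fit()`. This is a minimal sketch assuming the dataset yields batched `(features, label)` pairs and fits in memory; `dataset_to_numpy` is a hypothetical helper, not part of Scikeras, and the random dataset stands in for your own pipeline.

```python
import numpy as np
import tensorflow as tf

# Stand-in for an existing input pipeline: batched (features, label) pairs.
ds = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(100, 10).astype("float32"), np.random.randint(0, 3, 100))
).batch(32)


def dataset_to_numpy(dataset):
    """Concatenate every (x_batch, y_batch) pair into two NumPy arrays."""
    xs, ys = [], []
    for x_batch, y_batch in dataset.as_numpy_iterator():
        xs.append(x_batch)
        ys.append(y_batch)
    return np.concatenate(xs), np.concatenate(ys)


X, y = dataset_to_numpy(ds)  # plain arrays, accepted by fit()/predict()/score()
```

This loads the whole dataset into memory, so it suits small and medium datasets rather than streaming pipelines.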
Install
- pip install scikeras keras
- pip install scikeras 'keras[tensorflow]'
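After installing, a quick check can catch the Keras 2 / Scikeras 0.13 mismatch behind the `ImportError` above. This is a sketch using only the standard library; it encodes just the rules stated on this page (Scikeras >= 0.13.0 needs Keras >= 3 and Python >= 3.9), and `check_versions` and `installed` are hypothetical helper names.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version


def _major_minor(ver):
    """Parse the leading 'X.Y' of a version string into a tuple of ints."""
    m = re.match(r"(\d+)\.(\d+)", ver)
    if m is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return int(m.group(1)), int(m.group(2))


def check_versions(python, scikeras_ver, keras_ver):
    """Return a list of problems; an empty list means no known mismatch."""
    if scikeras_ver is None:
        return ["scikeras is not installed"]
    problems = []
    if _major_minor(scikeras_ver) >= (0, 13):
        if tuple(python[:2]) < (3, 9):
            problems.append("scikeras >= 0.13.0 requires Python >= 3.9")
        if keras_ver is None or _major_minor(keras_ver)[0] < 3:
            problems.append("scikeras >= 0.13.0 requires Keras 3: pip install --upgrade keras")
    return problems


def installed(pkg):
    """Installed version string of `pkg`, or None if it is missing."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for problem in check_versions(sys.version_info, installed("scikeras"), installed("keras")):
        print("WARNING:", problem)
```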
Imports
- KerasClassifier
  Old (no longer available): `from keras.wrappers.scikit_learn import KerasClassifier`
  Use instead: `from scikeras.wrappers import KerasClassifier`
- KerasRegressor
  Old (no longer available): `from tensorflow.keras.wrappers.scikit_learn import KerasRegressor`
  Use instead: `from scikeras.wrappers import KerasRegressor`
Quickstart
```python
import numpy as np
import keras
from scikeras.wrappers import KerasClassifier

# 1. Define a Keras model creation function
def build_classifier_model(meta):
    # meta holds data-dependent info such as n_features_in_ and n_classes_
    model = keras.Sequential([
        keras.Input(shape=(meta["n_features_in_"],)),
        keras.layers.Dense(10, activation="relu"),
        # One softmax unit per class (n_outputs_ is 1 for a single label column)
        keras.layers.Dense(meta["n_classes_"], activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

# 2. Generate some dummy data
X = np.random.rand(100, 10).astype(np.float32)
y = np.random.randint(0, 3, 100).astype(np.int32)  # 3 classes

# 3. Create a KerasClassifier instance
keras_clf = KerasClassifier(
    model=build_classifier_model,
    epochs=10,
    batch_size=32,
    verbose=0,  # Suppress per-epoch training output
)

# 4. Train the model using the Scikit-Learn API
keras_clf.fit(X, y)

# 5. Make predictions (returned as class labels)
predictions = keras_clf.predict(X[:5])
print(f"Predictions for first 5 samples: {predictions}")

# Evaluate with the Scikit-Learn .score() method (accuracy for classifiers)
score = keras_clf.score(X, y)
print(f"Model accuracy: {score:.4f}")
```