Scikit-Learn API wrapper for Keras

0.13.0 · active · verified Thu Apr 16

SciKeras provides a Scikit-Learn-compatible API wrapper for Keras models, so Keras models can be used directly with Scikit-Learn tools such as `GridSearchCV`, `Pipeline`, and the cross-validation utilities. The current version is 0.13.0. Releases typically come out every few months, often following new Keras or TensorFlow releases.

Quickstart

This quickstart demonstrates how to wrap a Keras model with `KerasClassifier` for use with Scikit-Learn's API. It shows model definition, data generation, training with `.fit()`, and prediction with `.predict()`.

import numpy as np
from keras import Input
from keras.models import Sequential
from keras.layers import Dense
from scikeras.wrappers import KerasClassifier

# 1. Define a Keras model creation function
def build_classifier_model(meta):
    # SciKeras fills `meta` at fit time with facts about the training data,
    # e.g. n_features_in_ (number of input columns) and n_classes_ (number of labels)
    model = Sequential([
        Input(shape=(meta["n_features_in_"],)),
        Dense(10, activation="relu"),
        Dense(meta["n_classes_"], activation="softmax"),  # one output unit per class
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

# 2. Generate some dummy data
X = np.random.rand(100, 10).astype(np.float32)
y = np.random.randint(0, 3, 100).astype(np.int32) # 3 classes

# 3. Create a KerasClassifier instance
keras_clf = KerasClassifier(
    model=build_classifier_model,
    epochs=10,
    batch_size=32,
    verbose=0 # Suppress verbose output for quickstart
)

# 4. Train the model using the Scikit-Learn API
keras_clf.fit(X, y)

# 5. Make predictions
predictions = keras_clf.predict(X[:5])
print(f"Predictions for first 5 samples: {predictions}")

# KerasClassifier.score() returns mean accuracy (computed here on the training data)
score = keras_clf.score(X, y)
print(f"Training accuracy: {score:.4f}")