Spark-sklearn: Scikit-learn on Spark

0.3.0 · abandoned · verified Fri Apr 17

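spark-sklearn provides integration tools for running scikit-learn's GridSearchCV and RandomizedSearchCV on Apache Spark clusters. It uses Spark to run the independent fits that make up a hyperparameter search (each parameter combination on each cross-validation fold) in parallel across the cluster. Each individual model is still trained by scikit-learn on a single worker, so the training data must fit in one machine's memory. The library is at version 0.3.0, last released in 2017, and appears to be abandoned, with no active development or maintenance. It was written against older scikit-learn releases and may not work with current ones.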
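The parallelism described above comes from the fact that a grid search is a set of independent jobs, one per (parameter combination, fold) pair. The stdlib sketch below illustrates only that decomposition; it is not spark-sklearn's code. `ThreadPoolExecutor` stands in for Spark executors, and the hypothetical `toy_score` function stands in for fitting and scoring a scikit-learn estimator on one fold.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product

# Same grid as the quickstart; folds are represented only by their index here.
PARAM_GRID = {"C": [0.1, 1, 10], "kernel": ["linear", "rbf"]}
N_FOLDS = 3


def expand_grid(param_grid):
    """Turn {'C': [...], 'kernel': [...]} into a list of parameter dicts."""
    keys = sorted(param_grid)
    return [dict(zip(keys, values)) for values in product(*(param_grid[k] for k in keys))]


def toy_score(params, fold):
    """Hypothetical stand-in for 'fit on the training split, score on the test split'."""
    bonus = 0.05 if params["kernel"] == "rbf" else 0.0
    return params["C"] / (1 + params["C"]) + bonus - 0.01 * fold


def build_tasks(param_grid, n_folds):
    """One independent task per (parameter combination, fold) pair."""
    candidates = expand_grid(param_grid)
    return [(i, params, fold) for i, (params, fold) in enumerate(product(candidates, range(n_folds)))]


def run_task(task):
    index, params, fold = task
    return index, params, fold, toy_score(params, fold)


if __name__ == "__main__":
    tasks = build_tasks(PARAM_GRID, N_FOLDS)
    # A Spark-based search would ship these tasks to executors; a thread pool plays that role here.
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(run_task, tasks))
    # 3 C values x 2 kernels x 3 folds = 18 independent fits
    print(len(results), "independent fits evaluated")
```

Because no task depends on another, adding workers shortens the search roughly in proportion until the number of tasks is exhausted. This is also why a single large model gains nothing from this approach.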

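Install

pip install spark-sklearn

PySpark and scikit-learn must also be available in the same Python environment.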

Imports

from spark_sklearn import GridSearchCV

Quickstart

This quickstart demonstrates how to use spark-sklearn's GridSearchCV to perform hyperparameter tuning for a scikit-learn SVC model, distributing the computation across a Spark cluster (or locally). It covers SparkContext initialization, data preparation, defining the estimator and parameter grid, fitting the model, and retrieving results.

import os
from pyspark import SparkContext
from spark_sklearn import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Initialize SparkContext.
# Use the master URL from the SPARK_MASTER environment variable if it is set
# (e.g. spark://host:7077); otherwise run locally on all cores.
master = os.environ.get('SPARK_MASTER', 'local[*]')

sc = None
try:
    sc = SparkContext(master=master, appName="SparkSklearnExample")

    # Generate some synthetic data
    X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Define the estimator and parameter grid
    estimator = SVC(gamma='auto', random_state=42)
    param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}

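    # spark_sklearn's GridSearchCV takes the SparkContext as its first argument
    # and runs the individual cross-validation fits as Spark tasks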
    clf = GridSearchCV(sc, estimator, param_grid, cv=3)
    clf.fit(X_train, y_train)

    print("Best parameters found:", clf.best_params_)
    print("Best cross-validation score:", clf.best_score_)
    print("Test set accuracy:", clf.score(X_test, y_test))

finally:
    # Always release Spark resources, even if fitting fails;
    # any exception still propagates with its full traceback
    if sc is not None:
        sc.stop()
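After the parallel fits return, the per-fold scores are grouped by parameter combination and averaged. `best_params_` is the combination with the highest mean score and `best_score_` is that mean. The stdlib sketch below reproduces that bookkeeping on made-up fold scores; the numbers are illustrative and are not output from the quickstart. It uses an unweighted mean, which matches scikit-learn when all test folds are the same size. The helper name `select_best` is hypothetical.

```python
from collections import defaultdict
from statistics import mean

# Illustrative (params, fold, score) records, as a parallel search might return them.
# These scores are made up for the example.
RESULTS = [
    ({"C": 1, "kernel": "linear"}, 0, 0.80),
    ({"C": 1, "kernel": "linear"}, 1, 0.84),
    ({"C": 1, "kernel": "linear"}, 2, 0.82),
    ({"C": 10, "kernel": "rbf"}, 0, 0.90),
    ({"C": 10, "kernel": "rbf"}, 1, 0.86),
    ({"C": 10, "kernel": "rbf"}, 2, 0.88),
]


def select_best(results):
    """Average fold scores per parameter set; return (best_params, best_mean_score)."""
    by_params = defaultdict(list)
    for params, _fold, score in results:
        # dicts are unhashable, so key on a sorted tuple of their items
        by_params[tuple(sorted(params.items()))].append(score)
    means = {key: mean(scores) for key, scores in by_params.items()}
    best_key = max(means, key=means.get)
    return dict(best_key), means[best_key]


if __name__ == "__main__":
    best_params, best_score = select_best(RESULTS)
    print("best_params:", best_params)
    print("best_score:", round(best_score, 4))
```

With the default `refit=True`, the search then retrains the estimator on the full training set using `best_params_`. That refitted model is what `clf.score(X_test, y_test)` evaluates in the quickstart.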
