Joblib Apache Spark Backend

0.6.0 · active · verified Thu Apr 16

Joblibspark provides an Apache Spark backend for the popular joblib library, enabling parallel tasks to be distributed across an Apache Spark cluster. This allows scikit-learn and other joblib-dependent libraries to leverage Spark's distributed computing capabilities. The current version is 0.6.0, released on April 7, 2025, and the project is actively developed and maintained.

Common errors

Warnings

Install
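The package is published on PyPI as `joblibspark`. A typical setup for the quickstart below also needs `pyspark` and `scikit-learn` (these extras are only required for the example, not by joblibspark itself):

```shell
# Install the Spark backend for joblib
pip install joblibspark

# Needed only to run the quickstart example below
pip install pyspark scikit-learn
```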

Imports

Quickstart

This quickstart demonstrates how to register the joblibspark backend and use it with both scikit-learn estimators and custom parallel functions. It assumes a SparkSession is available, either pre-configured in environments like Databricks or initialized locally.

from joblibspark import register_spark
from joblib import parallel_backend, Parallel, delayed
from pyspark.sql import SparkSession
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Initialize SparkSession (if not already running, e.g., in a Databricks notebook)
# In a Databricks notebook, the spark variable is usually pre-defined.
# For local testing, uncomment and run:
# spark = SparkSession.builder.appName("JoblibSparkTest").master("local[*]").getOrCreate()

# 1. Register the Spark backend
register_spark()

# 2. Example with scikit-learn (using a dummy model)
X, y = make_classification(n_samples=100, n_features=4, random_state=42)
model = RandomForestClassifier(n_estimators=10, random_state=42)

print("Fitting model using Spark backend...")
with parallel_backend('spark', n_jobs=-1):
    model.fit(X, y)
print("Model fitted successfully.")

# 3. Example with a custom parallel function
def process_item(item):
    return item * 2

items = list(range(10))
print(f"Processing items: {items}")
with parallel_backend('spark', n_jobs=-1):
    results = Parallel()(delayed(process_item)(i) for i in items)
print(f"Processed results: {results}")

# If the SparkSession was created manually above (e.g., for local testing), stop it:
# spark.stop()
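For a quick local sanity check without a Spark cluster, the same `Parallel`/`delayed` pattern from the quickstart runs unchanged on joblib's default local backend; only the `parallel_backend('spark', ...)` context is dropped. A minimal sketch:

```python
from joblib import Parallel, delayed

def process_item(item):
    # Same toy workload as in the quickstart: double each input.
    return item * 2

# With no backend context manager, joblib falls back to its default
# local backend, so this needs no Spark installation at all.
results = Parallel(n_jobs=2)(delayed(process_item)(i) for i in range(10))
print(results)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```

This makes it easy to verify the parallel logic locally before switching the backend to `'spark'` for cluster execution.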
