PySpark HNSW Library
pyspark-hnsw is a Python library that provides a distributed implementation of Hierarchical Navigable Small World (HNSW) graphs for Approximate Nearest Neighbor (ANN) search on Apache Spark. It enables efficient vector similarity search over large datasets from a PySpark environment, leveraging Spark's distributed processing. The current stable version on PyPI is 1.1.0; releases arrive at a moderate cadence, with minor updates in recent months.
Warnings
- gotcha There is a discrepancy between the latest PyPI version (1.1.0) and the latest GitHub release (1.2.1). Check which version you are actually installing and which features and fixes it carries.
- gotcha Building and querying HNSW indices, especially with high dimensionality or large datasets, can be memory and CPU intensive. Adjust Spark executor memory (`spark.executor.memory`), number of partitions (`setNumPartitions`), and HNSW parameters (`setM`, `setEf`) carefully.
- breaking Version 1.2.0 (not yet published to PyPI, which is still at 1.1.0) repackages classes (`Repackage classes to avoid JPMS issues`). While this primarily affects Java Module System users, it may alter internal class paths or dependencies, which could indirectly impact complex PySpark setups or users relying on specific internal JAR references.
- gotcha The `index_path` used to build the index must be accessible and writable by all Spark executors, and it should typically point to a distributed file system like HDFS, S3, or similar. Using a local path will store the index only on the driver or the first executor, which is not suitable for distributed use.
- gotcha Ensure your vectors are in a format compatible with `pyspark-hnsw`, typically `ArrayType(FloatType)`. Mismatched data types can lead to errors during index building or querying.
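The index-path gotcha above can be guarded against before an expensive build. The following is a hypothetical helper, not part of pyspark-hnsw, and the scheme list is illustrative rather than exhaustive:

```python
from urllib.parse import urlparse

# Illustrative set of URI schemes that indicate a shared/distributed filesystem.
DISTRIBUTED_SCHEMES = {"hdfs", "s3", "s3a", "s3n", "gs", "abfs", "abfss", "wasb", "wasbs"}

def is_distributed_path(path: str) -> bool:
    """Return True if the path's URI scheme points at a distributed filesystem."""
    return urlparse(path).scheme in DISTRIBUTED_SCHEMES

print(is_distributed_path("hdfs://namenode:8020/indices/hnsw"))  # True
print(is_distributed_path("hnsw_index_test_dir"))                # False: local path
```

A check like this is only a heuristic (a local path can still be a shared mount), but it catches the common mistake of passing a bare relative path on a multi-node cluster.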
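To avoid the vector-type mismatch described above, coerce vectors to plain Python floats before building the DataFrame. A minimal sketch; the helper name is our own, not a library API:

```python
def to_float_vector(values):
    """Coerce a sequence (e.g. a NumPy array row) into a plain list of Python
    floats, matching the ArrayType(FloatType()) column layout the index expects."""
    return [float(x) for x in values]

vec = to_float_vector([1, 2.5, 3])
print(vec)  # [1.0, 2.5, 3.0]
```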
Install
pip install pyspark-hnsw
Imports
- HnswIndex
from pyspark_hnsw import HnswIndex
Quickstart
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark_hnsw import HnswIndex
import numpy as np
import os
# Configure Spark (local mode for example)
conf = SparkConf().setAppName("HnswQuickstart").setMaster("local[*]")
sc = SparkContext(conf=conf)
spark = SparkSession(sc)
# Create some sample data with 128-dimensional vectors
data = [(i, [float(x) for x in np.random.rand(128)]) for i in range(1000)]
df = spark.createDataFrame(data, ["id", "vector"])
# Define a path for the index (local or distributed filesystem like HDFS/S3)
# Ensure this path is writable and accessible by Spark workers
index_path = "hnsw_index_test_dir"
# Clean up any previous index so the example is repeatable
if os.path.exists(index_path):
    import shutil
    shutil.rmtree(index_path)
# Build the HNSW index
hnsw_index = HnswIndex(spark, "id", "vector", index_path) \
    .setM(16) \
    .setEf(100) \
    .setNumPartitions(10) \
    .setDistanceType("cosine") \
    .build(df)
# Define a query vector
query_vector = [float(x) for x in np.random.rand(128)]
num_neighbors = 5
# Find nearest neighbors
result = hnsw_index.findNearestNeighbors(query_vector, num_neighbors)
result.show()
# Stop the Spark session
spark.stop()
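Because HNSW search is approximate, it can be worth spot-checking results against exact nearest neighbors on a small sample. A brute-force cosine-distance sketch in plain Python (independent of Spark; function names are our own):

```python
import math

def cosine_distance(a, b):
    """1 - cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return 1.0 - dot / (norm_a * norm_b)

def exact_neighbors(query, items, k):
    """items: list of (id, vector) pairs. Returns the k ids closest to query."""
    ranked = sorted(items, key=lambda item: cosine_distance(query, item[1]))
    return [item_id for item_id, _ in ranked[:k]]

items = [(0, [1.0, 0.0]), (1, [0.9, 0.1]), (2, [0.0, 1.0])]
print(exact_neighbors([1.0, 0.0], items, 2))  # [0, 1]
```

Comparing the overlap between these exact ids and the index's `findNearestNeighbors` output on a sampled subset gives a rough recall estimate for your chosen `setM`/`setEf` settings.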