PySpark HNSW Library

1.1.0 · active · verified Sat Apr 11

pyspark-hnsw is a Python library that provides a distributed implementation of Hierarchical Navigable Small Worlds (HNSW) for Approximate Nearest Neighbor (ANN) search on Apache Spark. It enables efficient vector similarity search on large datasets within a PySpark environment, leveraging Spark's distributed processing capabilities. The current stable version available on PyPI is 1.1.0, with a moderate release cadence, including minor updates in recent months.

Warnings

Install
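The package is published on PyPI under the name shown in the title, so a typical install (assuming a working Python environment with PySpark already available) is:

```shell
pip install pyspark-hnsw
```

Note that the library also has a JVM side: the matching `hnswlib-spark` artifact must be on Spark's classpath (for example via `spark-submit --packages`), and the exact Maven coordinates depend on your Spark and Scala versions, so check them against the project documentation.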

Imports
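The imports used in the quickstart below. `HnswSimilarity` is assumed to live in the `pyspark_hnsw.knn` module, following the library's Spark ML-style layout; verify the module path against your installed version.

```python
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark_hnsw.knn import HnswSimilarity
```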

Quickstart

This quickstart shows how to start a Spark session, create a sample DataFrame of vector data, build a distributed HNSW index with the `HnswSimilarity` estimator, and run a nearest-neighbor query. The class and parameter names below follow the library's Spark ML-style estimator API; verify them against the version you have installed, and configure Spark appropriately for your environment (e.g. local, YARN, Kubernetes).

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark_hnsw.knn import HnswSimilarity
import random

# Start a local Spark session; in cluster mode, point .master() at your cluster
spark = (
    SparkSession.builder
    .appName("HnswQuickstart")
    .master("local[*]")
    .getOrCreate()
)

# Create some sample data with 128-dimensional dense vectors
data = [(i, Vectors.dense([random.random() for _ in range(128)])) for i in range(1000)]
df = spark.createDataFrame(data, ["id", "features"])

# Configure the HNSW estimator (Spark ML Estimator/Model pattern)
hnsw = HnswSimilarity(
    identifierCol="id",
    featuresCol="features",
    queryIdentifierCol="id",
    distanceFunction="cosine",
    m=16,                 # max graph connections per node
    efConstruction=200,   # candidate-list size while building the index
    ef=100,               # candidate-list size while querying
    k=5,                  # number of neighbors to return per query row
    numPartitions=2,      # number of index partitions across the cluster
)

# fit() builds the distributed index
model = hnsw.fit(df)

# Query with a DataFrame of vectors; neighbors appear in the prediction column
query = spark.createDataFrame(
    [(0, Vectors.dense([random.random() for _ in range(128)]))],
    ["id", "features"],
)
model.transform(query).show(truncate=False)

# Stop the Spark session
spark.stop()
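Because HNSW returns approximate neighbors, it can be useful to sanity-check results on a small sample against an exact baseline. The sketch below is plain Python with no Spark dependency; it assumes cosine distance is defined as 1 minus cosine similarity, which is the usual convention for a "cosine" distance function.

```python
import math
import random

def cosine_distance(a, b):
    # Cosine distance = 1 - cosine similarity
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return 1.0 - dot / (norm_a * norm_b)

def brute_force_knn(query, items, k):
    # items: list of (id, vector) pairs; returns the k closest ids
    # by cosine distance (exact, O(n) per query)
    scored = sorted(items, key=lambda item: cosine_distance(query, item[1]))
    return [item_id for item_id, _ in scored[:k]]

random.seed(42)
items = [(i, [random.random() for _ in range(128)]) for i in range(1000)]
query = [random.random() for _ in range(128)]
print(brute_force_knn(query, items, 5))
```

Comparing this exact top-k against the ids returned by the approximate index gives a quick recall estimate for your chosen `m` and `ef` settings.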
