Dask-CUDA

26.4.0 · active · verified Thu Apr 16

Dask-CUDA is a Python library providing utilities to facilitate interactions between Dask and NVIDIA CUDA-enabled GPUs. It extends `dask.distributed`'s `LocalCluster` and `Worker` to manage and deploy Dask workers efficiently on GPU systems. Key features include automatic instantiation of per-GPU workers, setting CPU affinity for optimal performance, and robust GPU memory management, including spilling to host memory. It is a core component of the RAPIDS suite for GPU-accelerated data science. The library maintains an active development status with regular releases, currently at version 26.4.0.

Common errors

Warnings

Install
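A typical installation looks like the following. This is a sketch of the common channels: the PyPI package name `dask-cuda` is standard, while the exact conda channels and any CUDA-version suffix depend on your local CUDA toolkit.

```shell
# From PyPI:
pip install dask-cuda

# Or from conda, using the RAPIDS channels:
# conda install -c rapidsai -c conda-forge -c nvidia dask-cuda
```

Dask-CUDA requires an NVIDIA GPU with a compatible driver; it is usually installed alongside the rest of the RAPIDS suite (cuDF, RMM).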

Imports
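The imports below are the ones used throughout this page; running them requires the `dask-cuda` package and an NVIDIA GPU environment.

```python
from dask_cuda import LocalCUDACluster  # GPU-aware variant of LocalCluster
from dask.distributed import Client     # client for submitting work to the cluster
```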

Quickstart

This quickstart demonstrates how to set up a `LocalCUDACluster` and connect a `dask.distributed.Client`. It highlights common configurations: specifying visible GPUs, configuring the RAPIDS Memory Manager (RMM) for memory pooling, and enabling cuDF spilling to avoid out-of-memory errors on large datasets. Wrapping the setup in an `if __name__ == "__main__":` block is essential in standalone scripts, because Dask spawns worker processes that may re-import the script.

import os
from dask_cuda import LocalCUDACluster
from dask.distributed import Client

if __name__ == "__main__":
    # Recommended to run inside an if __name__ == "__main__": block
    # Configure for 2 GPUs, 90% RMM pool size, and enable cuDF spilling
    cluster = LocalCUDACluster(
        CUDA_VISIBLE_DEVICES="0,1",  # Example: use devices 0 and 1
        rmm_pool_size=0.9,           # Use 90% of GPU memory as a pool
        enable_cudf_spill=True,      # Enable spilling to host memory if needed
        local_directory=os.environ.get('DASK_LOCAL_DIRECTORY', '/tmp/dask-cuda')
    )
    client = Client(cluster)

    print(f"Dask-CUDA cluster dashboard link: {client.dashboard_link}")
    # Your Dask-accelerated GPU computations go here
    # For example, with the cuDF backend for Dask DataFrame:
    # import dask
    # import dask.dataframe as dd
    # dask.config.set({"dataframe.backend": "cudf"})
    # ddf = dd.read_csv("my_gpu_data.csv")
    # result = ddf.groupby("col").sum().compute()

    client.close()
    cluster.close()
