Fast GroupBy operations for Dask Arrays
Flox is a Python library that provides strategies for fast GroupBy reductions with dask.array, aimed at workloads such as climatologies, resampling, and histogramming. It was formerly known as `dask_groupby`. It integrates with xarray, which can use it to speed up GroupBy and resampling operations.
Warnings
- breaking The library was previously known as `dask_groupby`. Code relying on the old package name or import paths will break.
- gotcha When `flox` is installed, Xarray (version >= 2022.06.0) automatically uses it by default for its `.groupby`, `.groupby_bins`, and `.resample` reductions. This implicit switch can change performance characteristics or surface underlying `flox` issues. It can be turned off with `xr.set_options(use_flox=False)`.
- gotcha Custom reductions defined with `Aggregation` instances are less mature than the built-in reductions and may fail or behave in undefined ways in some scenarios.
- gotcha `flox` aggregations on Dask can use a lot of memory, particularly when Dask keeps running lower-level tasks (e.g., data loading) before the higher-level reduction tasks that consume their outputs have been computed, so intermediate results pile up in memory.
- gotcha For Dask arrays, `flox` uses heuristics (since v0.9.0) to choose a parallel algorithm (`map-reduce`, `blockwise`, or `cohorts`) automatically. The heuristics are generally robust, but specific data distributions or chunking patterns may benefit from setting the `method` parameter of `groupby_reduce` or `xarray_reduce` explicitly.
Install
- pip
pip install flox
Imports
- groupby_reduce
from flox import groupby_reduce
- xarray_reduce
from flox.xarray import xarray_reduce
Quickstart
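import numpy as np
import dask.array as da
from flox import groupby_reduce
# Create a sample Dask array: 10 rows x 1000 samples, chunked along the sample (last) axis
data = da.random.random((10, 1000), chunks=(10, 100))
# Create a 'by' array of labels (categories 0-9), one per sample.
# flox reduces over the trailing dimension(s) of `data`, so groups.shape must equal data.shape[-1:]
groups = np.random.randint(0, 10, size=1000)
# Perform a GroupBy reduction (e.g., mean); the result is a lazy Dask array of shape (10, 10)
result_mean, group_labels = groupby_reduce(
    data, groups, func="mean", expected_groups=np.arange(10)
)
# The group axis is the last axis of the result
print("Grouped Means (first 5 groups):\n", result_mean.compute()[:, :5])
print("Group Labels:\n", group_labels)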