From d65181c1b08fb3249112ad008f01aff633d1900a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 16 Sep 2024 16:42:09 -0600 Subject: [PATCH] Avoid rechunking when preferred_method="blockwise" (#394) * Avoid rechunking when preferred_method="blockwise" * Add test * fix --- flox/core.py | 4 +++- tests/test_core.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index c419f746..7cab9b39 100644 --- a/flox/core.py +++ b/flox/core.py @@ -642,6 +642,7 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray Rechunked array """ + # TODO: this should be unnecessary? labels = factorize_((labels,), axes=())[0] chunks = array.chunks[axis] newchunks = _get_optimal_chunks_for_groups(chunks, labels) @@ -2623,7 +2624,8 @@ def groupby_reduce( partial_agg = partial(dask_groupby_agg, **kwargs) - if method == "blockwise" and by_.ndim == 1: + # if preferred method is already blockwise, no need to rechunk + if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1: array = rechunk_for_blockwise(array, axis=-1, labels=by_) result, groups = partial_agg( diff --git a/tests/test_core.py b/tests/test_core.py index cef9ad8a..94e32a6d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1997,3 +1997,12 @@ def test_agg_dtypes(func, engine): ) expected = _get_array_func(func)(counts, dtype="uint8") assert actual.dtype == np.uint8 == expected.dtype + + +@requires_dask +def test_blockwise_avoid_rechunk(): + array = dask.array.zeros((6,), chunks=(2, 4), dtype=np.int64) + by = np.array(["1", "1", "0", "", "0", ""], dtype="