From c1b5c1c6406e7cc76cfed38d221c458c654f29a7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 14 Sep 2024 19:33:21 -0600 Subject: [PATCH 1/3] Avoid rechunking when preferred_method="blockwise" --- flox/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index c419f746..7cab9b39 100644 --- a/flox/core.py +++ b/flox/core.py @@ -642,6 +642,7 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray Rechunked array """ + # TODO: this should be unnecessary? labels = factorize_((labels,), axes=())[0] chunks = array.chunks[axis] newchunks = _get_optimal_chunks_for_groups(chunks, labels) @@ -2623,7 +2624,8 @@ def groupby_reduce( partial_agg = partial(dask_groupby_agg, **kwargs) - if method == "blockwise" and by_.ndim == 1: + # if preferred method is already blockwise, no need to rechunk + if preferred_method != "blockwise" and method == "blockwise" and by_.ndim == 1: array = rechunk_for_blockwise(array, axis=-1, labels=by_) result, groups = partial_agg( From 40efff208f2c0b435438f1e31ec9dc6dcc8b4fc9 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 16 Sep 2024 14:32:40 -0600 Subject: [PATCH 2/3] Add test --- tests/test_core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index cef9ad8a..4ead7bfb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1997,3 +1997,11 @@ def test_agg_dtypes(func, engine): ) expected = _get_array_func(func)(counts, dtype="uint8") assert actual.dtype == np.uint8 == expected.dtype + + +def test_blockwise_avoid_rechunk(): + array = dask.array.zeros((6,), chunks=(2, 4), dtype=np.int64) + by = np.array(["1", "1", "0", "", "0", ""], dtype=" Date: Mon, 16 Sep 2024 14:50:59 -0600 Subject: [PATCH 3/3] fix --- tests/test_core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_core.py b/tests/test_core.py index 4ead7bfb..94e32a6d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1999,6 +1999,7 @@ def test_agg_dtypes(func, engine): assert actual.dtype == np.uint8 == expected.dtype +@requires_dask def test_blockwise_avoid_rechunk(): array = dask.array.zeros((6,), chunks=(2, 4), dtype=np.int64) by = np.array(["1", "1", "0", "", "0", ""], dtype="