diff --git a/flox/core.py b/flox/core.py index 1c9599a1..e082c7e2 100644 --- a/flox/core.py +++ b/flox/core.py @@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool: def _is_first_last_reduction(func: T_Agg) -> bool: - return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"] + if isinstance(func, Aggregation): + func = func.name + return func in ["nanfirst", "nanlast", "first", "last"] def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex: @@ -1642,7 +1644,12 @@ def dask_groupby_agg( # This allows us to discover groups at compute time, support argreductions, lower intermediate # memory usage (but method="cohorts" would also work to reduce memory in some cases) labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None - do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown + do_grouped_combine = ( + _is_arg_reduction(agg) + or labels_are_unknown + or (_is_first_last_reduction(agg) and array.dtype.kind != "f") + ) + do_simple_combine = not do_grouped_combine if method == "blockwise": # use the "non dask" code path, but applied blockwise @@ -1986,8 +1993,13 @@ def _validate_reindex( expected_groups, any_by_dask: bool, is_dask_array: bool, + array_dtype: Any, ) -> bool | None: # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa + def first_or_last(): + return func in ["first", "last"] or ( + _is_first_last_reduction(func) and array_dtype.kind != "f" + ) all_numpy = not is_dask_array and not any_by_dask if reindex is True and not all_numpy: @@ -1997,7 +2009,7 @@ def _validate_reindex( raise ValueError( "reindex=True is not a valid choice for method='blockwise' or method='cohorts'." ) - if func in ["first", "last"]: + if first_or_last(): raise ValueError("reindex must be None or False when func is 'first' or 'last.") if reindex is None: @@ -2008,9 +2020,10 @@ def _validate_reindex( if all_numpy: return True - if func in ["first", "last"]: + if first_or_last(): # have to do the grouped_combine since there's no good fill_value - reindex = False + # Also needed for nanfirst, nanlast with no-NaN dtypes + return False if method == "blockwise": # for grouping by dask arrays, we set reindex=True @@ -2413,7 +2426,13 @@ def groupby_reduce( raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.") reindex = _validate_reindex( - reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array) + reindex, + func, + method, + expected_groups, + any_by_dask, + is_duck_dask_array(array), + array.dtype, ) if not is_duck_array(array): @@ -2601,7 +2620,7 @@ def groupby_reduce( # TODO: clean this up reindex = _validate_reindex( - reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array) + reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype ) if TYPE_CHECKING: diff --git a/tests/test_core.py b/tests/test_core.py index e12e695d..22864a05 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset(): ) +@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]]) +@pytest.mark.parametrize( + "func", + [ + # "first", "last", + "nanfirst", + "nanlast", + ], +) +@pytest.mark.parametrize( + "chunks", + [ + None, + pytest.param(1, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + pytest.param(2, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + pytest.param(3, marks=pytest.mark.skipif(not has_dask, reason="no dask")), + ], +) +def test_first_last_useless(func, chunks, group_idx): + array = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int8) + if chunks is not None: + array = dask.array.from_array(array, chunks=chunks) + actual, _ = groupby_reduce(array, np.array(group_idx), func=func, engine="numpy") + expected = np.array([[0, 0], [0, 0]], dtype=np.int8) + assert_equal(actual, expected) + + @pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"]) @pytest.mark.parametrize("axis", [(0, 1)]) def test_first_last_disallowed(axis, func):