From d6ade218cc93bae3ce2344a5c8ab4a03b777c176 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Tue, 21 May 2024 21:17:46 +0000 Subject: [PATCH 1/4] Deprecate collect --- python/cudf/cudf/core/groupby/groupby.py | 8 +++++++- python/dask_cudf/dask_cudf/groupby.py | 4 ++-- python/dask_cudf/dask_cudf/tests/test_groupby.py | 4 +++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 3e4b8192888..dd64573efab 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -2180,7 +2180,13 @@ def func(x): @_cudf_nvtx_annotate def collect(self): """Get a list of all the values for each column in each group.""" - return self.agg("collect") + warnings.warn( + "Groupby.collect is deprecated and " + "will be removed in a future version. " + "Use `.agg(list)` instead.", + FutureWarning, + ) + return self.agg(list) @_cudf_nvtx_annotate def unique(self): diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 43ad4f0fee3..902eb99b1f1 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -166,7 +166,7 @@ def max(self, split_every=None, split_out=1): def collect(self, split_every=None, split_out=1): return _make_groupby_agg_call( self, - self._make_groupby_method_aggs("collect"), + self._make_groupby_method_aggs(list), split_every, split_out, ) @@ -310,7 +310,7 @@ def max(self, split_every=None, split_out=1): def collect(self, split_every=None, split_out=1): return _make_groupby_agg_call( self, - {self._slice: "collect"}, + {self._slice: list}, split_every, split_out, )[self._slice] diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index dc279bfa690..9efe3ba082c 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -9,6 +9,7 @@ from dask.utils_test import hlg_layer import cudf +from cudf.testing._utils import expect_warning_if import dask_cudf from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized @@ -62,7 +63,8 @@ def test_groupby_basic(series, aggregation, pdf): check_dtype = aggregation != "count" - expect = getattr(gdf_grouped, aggregation)() + with expect_warning_if(aggregation == "collect"): + expect = getattr(gdf_grouped, aggregation)() actual = getattr(ddf_grouped, aggregation)() if not QUERY_PLANNING_ON: From 98499a83b5033ba082ac74b8f9570630dec9efc3 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 22 May 2024 00:00:02 +0000 Subject: [PATCH 2/4] Update codepaths to not use collect --- python/dask_cudf/dask_cudf/expr/_groupby.py | 8 ++++---- python/dask_cudf/dask_cudf/groupby.py | 2 +- python/dask_cudf/dask_cudf/tests/test_groupby.py | 6 ++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 116893891e3..8bfa7dc94d0 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -17,11 +17,11 @@ class Collect(SingleAggregation): @staticmethod def groupby_chunk(arg): - return arg.agg("collect") + return arg.agg(list) @staticmethod def groupby_aggregate(arg): - gb = arg.agg("collect") + gb = arg.agg(list) if gb.ndim > 1: for col in gb.columns: gb[col] = gb[col].list.concat() @@ -31,7 +31,7 @@ def groupby_aggregate(arg): collect_aggregation = Aggregation( - name="collect", + name="list", chunk=Collect.groupby_chunk, agg=Collect.groupby_aggregate, ) @@ -41,7 +41,7 @@ def _translate_arg(arg): # Helper function to translate args so that # they can be processed correctly by upstream # dask & dask-expr. Right now, the only necessary - # translation is "collect" aggregations. + # translation is "list" aggregations. if isinstance(arg, dict): return {k: _translate_arg(v) for k, v in arg.items()} elif isinstance(arg, list): diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 902eb99b1f1..3126f1c97a2 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -28,7 +28,7 @@ "sum", "min", "max", - "collect", + list, "first", "last", ) diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 9efe3ba082c..527928022f8 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -9,7 +9,6 @@ from dask.utils_test import hlg_layer import cudf -from cudf.testing._utils import expect_warning_if import dask_cudf from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized @@ -63,9 +62,8 @@ def test_groupby_basic(series, aggregation, pdf): check_dtype = aggregation != "count" - with expect_warning_if(aggregation == "collect"): - expect = getattr(gdf_grouped, aggregation)() - actual = getattr(ddf_grouped, aggregation)() + expect = gdf_grouped.agg(aggregation) + actual = ddf_grouped.agg(aggregation) if not QUERY_PLANNING_ON: assert_cudf_groupby_layers(actual) From 59b605be6d381b7df47be051236021c85d5cd06e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 22 May 2024 07:13:49 -0700 Subject: [PATCH 3/4] get all tests working --- python/cudf/cudf/core/groupby/groupby.py | 16 +++++++----- python/dask_cudf/dask_cudf/expr/_groupby.py | 26 +++++++++++-------- python/dask_cudf/dask_cudf/groupby.py | 19 +++++++++----- .../dask_cudf/dask_cudf/tests/test_groupby.py | 10 +++++-- 4 files changed, 45 insertions(+), 26 deletions(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index dd64573efab..bf24864c29d 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -40,6 +40,15 @@ from cudf.utils.utils import GetAttrGetItemMixin +def _deprecate_collect(): + warnings.warn( + "Groupby.collect is deprecated and " + "will be removed in a future version. " + "Use `.agg(list)` instead.", + FutureWarning, + ) + + # The three functions below return the quantiles [25%, 50%, 75%] # respectively, which are called in the describe() method to output # the summary stats of a GroupBy object @@ -2180,12 +2189,7 @@ def func(x): @_cudf_nvtx_annotate def collect(self): """Get a list of all the values for each column in each group.""" - warnings.warn( - "Groupby.collect is deprecated and " - "will be removed in a future version. " - "Use `.agg(list)` instead.", - FutureWarning, - ) + _deprecate_collect() return self.agg(list) @_cudf_nvtx_annotate diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 116893891e3..65688115b59 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -9,19 +9,21 @@ from dask.dataframe.groupby import Aggregation +from cudf.core.groupby.groupby import _deprecate_collect + ## ## Custom groupby classes ## -class Collect(SingleAggregation): +class ListAgg(SingleAggregation): @staticmethod def groupby_chunk(arg): - return arg.agg("collect") + return arg.agg(list) @staticmethod def groupby_aggregate(arg): - gb = arg.agg("collect") + gb = arg.agg(list) if gb.ndim > 1: for col in gb.columns: gb[col] = gb[col].list.concat() @@ -30,10 +32,10 @@ def groupby_aggregate(arg): return gb.list.concat() -collect_aggregation = Aggregation( - name="collect", - chunk=Collect.groupby_chunk, - agg=Collect.groupby_aggregate, +list_aggregation = Aggregation( + name="list", + chunk=ListAgg.groupby_chunk, + agg=ListAgg.groupby_aggregate, ) @@ -41,13 +43,13 @@ def _translate_arg(arg): # Helper function to translate args so that # they can be processed correctly by upstream # dask & dask-expr. Right now, the only necessary - # translation is "collect" aggregations. + # translation is list aggregations. if isinstance(arg, dict): return {k: _translate_arg(v) for k, v in arg.items()} elif isinstance(arg, list): return [_translate_arg(x) for x in arg] elif arg in ("collect", "list", list): - return collect_aggregation + return list_aggregation else: return arg @@ -84,7 +86,8 @@ def __getitem__(self, key): return g def collect(self, **kwargs): - return self._single_agg(Collect, **kwargs) + _deprecate_collect() + return self._single_agg(ListAgg, **kwargs) def aggregate(self, arg, **kwargs): return super().aggregate(_translate_arg(arg), **kwargs) @@ -96,7 +99,8 @@ def __init__(self, *args, observed=None, **kwargs): super().__init__(*args, observed=observed, **kwargs) def collect(self, **kwargs): - return self._single_agg(Collect, **kwargs) + _deprecate_collect() + return self._single_agg(ListAgg, **kwargs) def aggregate(self, arg, **kwargs): return super().aggregate(_translate_arg(arg), **kwargs) diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 902eb99b1f1..14d262bb38e 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -15,6 +15,7 @@ from dask.utils import funcname import cudf +from cudf.core.groupby.groupby import _deprecate_collect from cudf.utils.nvtx_annotation import _dask_cudf_nvtx_annotate from dask_cudf.sorting import _deprecate_shuffle_kwarg @@ -28,9 +29,9 @@ "sum", "min", "max", - "collect", "first", "last", + list, ) @@ -164,6 +165,7 @@ def max(self, split_every=None, split_out=1): @_dask_cudf_nvtx_annotate @_check_groupby_optimized def collect(self, split_every=None, split_out=1): + _deprecate_collect() return _make_groupby_agg_call( self, self._make_groupby_method_aggs(list), @@ -308,6 +310,7 @@ def max(self, split_every=None, split_out=1): @_dask_cudf_nvtx_annotate @_check_groupby_optimized def collect(self, split_every=None, split_out=1): + _deprecate_collect() return _make_groupby_agg_call( self, {self._slice: list}, @@ -472,7 +475,7 @@ def groupby_agg( This aggregation algorithm only supports the following options - * "collect" + * "list" * "count" * "first" * "last" @@ -667,8 +670,8 @@ def _redirect_aggs(arg): sum: "sum", max: "max", min: "min", - list: "collect", - "list": "collect", + "collect": list, + "list": list, } if isinstance(arg, dict): new_arg = dict() @@ -704,7 +707,7 @@ def _aggs_optimized(arg, supported: set): _global_set = set(arg) return bool(_global_set.issubset(supported)) - elif isinstance(arg, str): + elif isinstance(arg, (str, type)): return arg in supported return False @@ -783,6 +786,8 @@ def _tree_node_agg(df, gb_cols, dropna, sort, sep): agg = col.split(sep)[-1] if agg in ("count", "sum"): agg_dict[col] = ["sum"] + elif agg == "list": + agg_dict[col] = [list] elif agg in OPTIMIZED_AGGS: agg_dict[col] = [agg] else: @@ -873,8 +878,8 @@ def _finalize_gb_agg( gb.drop(columns=[sum_name], inplace=True) if "count" not in agg_list: gb.drop(columns=[count_name], inplace=True) - if "collect" in agg_list: - collect_name = _make_name((col, "collect"), sep=sep) + if list in agg_list: + collect_name = _make_name((col, "list"), sep=sep) gb[collect_name] = gb[collect_name].list.concat() # Ensure sorted keys if `sort=True` diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 9efe3ba082c..e0b83ddc213 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -48,7 +48,13 @@ def pdf(request): return pdf -@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS) +# NOTE: We only want to test aggregation "methods" here, +# so we need to leave out `list`. We also include a +# deprecation check for "collect". +@pytest.mark.parametrize( + "aggregation", + tuple(set(OPTIMIZED_AGGS) - {list}) + ("collect",), +) @pytest.mark.parametrize("series", [False, True]) def test_groupby_basic(series, aggregation, pdf): gdf = cudf.DataFrame.from_pandas(pdf) @@ -65,7 +71,7 @@ def test_groupby_basic(series, aggregation, pdf): with expect_warning_if(aggregation == "collect"): expect = getattr(gdf_grouped, aggregation)() - actual = getattr(ddf_grouped, aggregation)() + actual = getattr(ddf_grouped, aggregation)() if not QUERY_PLANNING_ON: assert_cudf_groupby_layers(actual) From 3e0b80fac068614511e798867312537b64083336 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 22 May 2024 14:38:19 +0000 Subject: [PATCH 4/4] sort --- python/dask_cudf/dask_cudf/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index e0b83ddc213..cf916b713b2 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -53,7 +53,7 @@ def pdf(request): # deprecation check for "collect". @pytest.mark.parametrize( "aggregation", - tuple(set(OPTIMIZED_AGGS) - {list}) + ("collect",), + sorted(tuple(set(OPTIMIZED_AGGS) - {list}) + ("collect",)), ) @pytest.mark.parametrize("series", [False, True]) def test_groupby_basic(series, aggregation, pdf):