From 4ad1e51e8cafe05ce754fc2928df71cc5a3ef2b0 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Thu, 12 May 2022 20:21:58 -0400 Subject: [PATCH] Allow string aggs for `dask_cudf.CudfDataFrameGroupBy.aggregate` (#10222) I noticed that `CudfDataFrameGroupBy.aggregate` doesn't actually support passing aggregations as strings; for example, something like ```python import cudf import dask_cudf gdf = cudf.DataFrame({'id4': 4*list(range(6)), 'id5': 4*list(reversed(range(6))), 'v3': 6*list(range(4))}) gddf = dask_cudf.from_cudf(gdf, npartitions=5) gddf.groupby("id4").agg("mean") ``` would actually end up using the upstream `aggregate` implementation. This is because: - `CudfDataFrameGroupBy.aggregate` does not convert string aggs to a dict before calling `_aggs_supported` on them - `_aggs_supported` only handles list / dict aggs, returning `False` otherwise I've resolved this by adding string support to `_aggs_supported`, and moving the conversion of aggs to the internal `groupby_agg`. It looks like this is exposing some failures for `first` and `last` groupby aggs, as tests that were originally using upstream Dask to compute these aggregations (I assume accidentally, since these aggregations are listed as supported) are now using dask-cuDF and getting the wrong result. 
Authors: - Charles Blackmon-Luca (https://github.com/charlesbluca) Approvers: - Bradley Dice (https://github.com/bdice) - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/10222 --- python/cudf/cudf/core/resample.py | 2 +- python/dask_cudf/dask_cudf/groupby.py | 6 +++++- .../dask_cudf/dask_cudf/tests/test_groupby.py | 20 +++++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/cudf/cudf/core/resample.py b/python/cudf/cudf/core/resample.py index 2bed71ea751..57630e7d4a9 100644 --- a/python/cudf/cudf/core/resample.py +++ b/python/cudf/cudf/core/resample.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & +# SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & # AFFILIATES. All rights reserved. SPDX-License-Identifier: # Apache-2.0 # diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index a64aabe1a6b..22705c2b83b 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -245,6 +245,7 @@ def last(self, split_every=None, split_out=1): def aggregate(self, arg, split_every=None, split_out=1): if arg == "size": return self.size() + arg = _redirect_aggs(arg) if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS): @@ -431,6 +432,7 @@ def last(self, split_every=None, split_out=1): def aggregate(self, arg, split_every=None, split_out=1): if arg == "size": return self.size() + arg = _redirect_aggs(arg) if not isinstance(arg, dict): @@ -503,7 +505,7 @@ def groupby_agg( if isinstance(gb_cols, str): gb_cols = [gb_cols] columns = [c for c in ddf.columns if c not in gb_cols] - if isinstance(aggs, list): + if not isinstance(aggs, dict): aggs = {col: aggs for col in columns} # Assert if our output will have a MultiIndex; this will be the case if @@ -665,6 +667,8 @@ def _aggs_supported(arg, supported: set): _global_set = set(arg) return 
bool(_global_set.issubset(supported)) + elif isinstance(arg, str): + return arg in supported return False diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 5aa9cffb789..2b7f2bdae36 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -58,7 +58,10 @@ def test_groupby_basic(series, aggregation): "func", [ lambda df: df.groupby("x").agg({"y": "max"}), + lambda df: df.groupby("x").agg(["sum", "max"]), lambda df: df.groupby("x").y.agg(["sum", "max"]), + lambda df: df.groupby("x").agg("sum"), + lambda df: df.groupby("x").y.agg("sum"), ], ) def test_groupby_agg(func): @@ -663,11 +666,20 @@ def test_groupby_agg_redirect(aggregations): @pytest.mark.parametrize( - "arg", - [["not_supported"], {"a": "not_supported"}, {"a": ["not_supported"]}], + "arg,supported", + [ + ("sum", True), + (["sum"], True), + ({"a": "sum"}, True), + ({"a": ["sum"]}, True), + ("not_supported", False), + (["not_supported"], False), + ({"a": "not_supported"}, False), + ({"a": ["not_supported"]}, False), + ], ) -def test_is_supported(arg): - assert _aggs_supported(arg, {"supported"}) is False +def test_is_supported(arg, supported): + assert _aggs_supported(arg, SUPPORTED_AGGS) is supported def test_groupby_unique_lists():