diff --git a/python/cudf/cudf/core/resample.py b/python/cudf/cudf/core/resample.py index 2bed71ea751..57630e7d4a9 100644 --- a/python/cudf/cudf/core/resample.py +++ b/python/cudf/cudf/core/resample.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & +# SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & # AFFILIATES. All rights reserved. SPDX-License-Identifier: # Apache-2.0 # diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index a64aabe1a6b..22705c2b83b 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -245,6 +245,7 @@ def last(self, split_every=None, split_out=1): def aggregate(self, arg, split_every=None, split_out=1): if arg == "size": return self.size() + arg = _redirect_aggs(arg) if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS): @@ -431,6 +432,7 @@ def last(self, split_every=None, split_out=1): def aggregate(self, arg, split_every=None, split_out=1): if arg == "size": return self.size() + arg = _redirect_aggs(arg) if not isinstance(arg, dict): @@ -503,7 +505,7 @@ def groupby_agg( if isinstance(gb_cols, str): gb_cols = [gb_cols] columns = [c for c in ddf.columns if c not in gb_cols] - if isinstance(aggs, list): + if not isinstance(aggs, dict): aggs = {col: aggs for col in columns} # Assert if our output will have a MultiIndex; this will be the case if @@ -665,6 +667,8 @@ def _aggs_supported(arg, supported: set): _global_set = set(arg) return bool(_global_set.issubset(supported)) + elif isinstance(arg, str): + return arg in supported return False diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 5aa9cffb789..2b7f2bdae36 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -58,7 +58,10 @@ def test_groupby_basic(series, aggregation): "func", [ lambda df: df.groupby("x").agg({"y": "max"}), + lambda df: df.groupby("x").agg(["sum", "max"]), lambda df: df.groupby("x").y.agg(["sum", "max"]), + lambda df: df.groupby("x").agg("sum"), + lambda df: df.groupby("x").y.agg("sum"), ], ) def test_groupby_agg(func): @@ -663,11 +666,20 @@ def test_groupby_agg_redirect(aggregations): @pytest.mark.parametrize( - "arg", - [["not_supported"], {"a": "not_supported"}, {"a": ["not_supported"]}], + "arg,supported", + [ + ("sum", True), + (["sum"], True), + ({"a": "sum"}, True), + ({"a": ["sum"]}, True), + ("not_supported", False), + (["not_supported"], False), + ({"a": "not_supported"}, False), + ({"a": ["not_supported"]}, False), + ], ) -def test_is_supported(arg): - assert _aggs_supported(arg, {"supported"}) is False +def test_is_supported(arg, supported): + assert _aggs_supported(arg, SUPPORTED_AGGS) is supported def test_groupby_unique_lists():