From bf5809d92c9b764b594171abf9f8793dd95eba9b Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 23 Aug 2023 12:00:21 -0700 Subject: [PATCH 1/4] Fix type mismatch in empty reduction ops --- python/cudf/cudf/core/groupby/groupby.py | 11 +++++++++- python/cudf/cudf/tests/test_groupby.py | 27 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 2ed9bed5b49..4c2c656c7db 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -527,6 +527,12 @@ def agg(self, func): 1 1.5 1.75 2.0 2.0 2 3.0 3.00 1.0 1.0 """ + is_empty = self.obj.empty + op_name = func.__name__ if callable(func) else func + is_reduction = ( + isinstance(op_name, str) + and op_name in Reducible._SUPPORTED_REDUCTIONS + ) column_names, columns, normalized_aggs = self._normalize_aggs(func) orig_dtypes = tuple(c.dtype for c in columns) @@ -563,7 +569,10 @@ def agg(self, func): ): # Structs lose their labels which we reconstruct here col = col._with_type_metadata(cudf.ListDtype(orig_dtype)) - data[key] = col + if is_empty and is_reduction and len(col) == 0: + data[key] = col.astype(orig_dtype) + else: + data[key] = col data = ColumnAccessor(data, multiindex=multilevel) if not multilevel: data = data.rename_levels({np.nan: None}, level=0) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index e578e1061ca..47a8d491027 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -3340,3 +3340,30 @@ def test_group_by_pandas_sort_order(groups, sort): pdf.groupby(groups, sort=sort).sum(), df.groupby(groups, sort=sort).sum(), ) + + +@pytest.mark.parametrize( + "dtype", + ["int32", "int64", "float64", "datetime64[ns]", "timedelta64[ns]", "bool"], +) +@pytest.mark.parametrize( + "reduce_op", + [ + "min", + "max", + "idxmin", + "idxmax", + "first", + "last", + ], +) +def test_group_by_empty_reduction(dtype, reduce_op): + gdf = cudf.DataFrame({"a": [], "b": [], "c": []}, dtype=dtype) + pdf = gdf.to_pandas() + + gg = gdf.groupby("a")["c"] + pg = pdf.groupby("a")["c"] + + assert_eq( + getattr(gg, reduce_op)(), getattr(pg, reduce_op)(), check_dtype=True + ) From b98be8bcef3b16ab3d864b184835d1647c4fa817 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 23 Aug 2023 12:57:11 -0700 Subject: [PATCH 2/4] Move computations closer --- python/cudf/cudf/core/groupby/groupby.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 4c2c656c7db..6196147b9a1 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -527,12 +527,6 @@ def agg(self, func): 1 1.5 1.75 2.0 2.0 2 3.0 3.00 1.0 1.0 """ - is_empty = self.obj.empty - op_name = func.__name__ if callable(func) else func - is_reduction = ( - isinstance(op_name, str) - and op_name in Reducible._SUPPORTED_REDUCTIONS - ) column_names, columns, normalized_aggs = self._normalize_aggs(func) orig_dtypes = tuple(c.dtype for c in columns) @@ -558,8 +552,8 @@ def agg(self, func): orig_dtypes, ): for agg, col in zip(aggs, cols): + agg_name = agg.__name__ if callable(agg) else agg if multilevel: - agg_name = agg.__name__ if callable(agg) else agg key = (col_name, agg_name) else: key = col_name @@ -569,7 +563,15 @@ def agg(self, func): ): # Structs lose their labels which we reconstruct here col = col._with_type_metadata(cudf.ListDtype(orig_dtype)) - if is_empty and is_reduction and len(col) == 0: + + if ( + self.obj.empty + and ( + isinstance(agg_name, str) + and agg_name in Reducible._SUPPORTED_REDUCTIONS + ) + and len(col) == 0 + ): data[key] = col.astype(orig_dtype) else: data[key] = col From 6b3584c3c46134a98c4069e2559f2431624caec2 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 23 Aug 2023 13:16:02 -0700 Subject: [PATCH 3/4] fix --- python/cudf/cudf/core/groupby/groupby.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 6196147b9a1..163ba4fc4a3 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -571,6 +571,13 @@ def agg(self, func): and agg_name in Reducible._SUPPORTED_REDUCTIONS ) and len(col) == 0 + and not isinstance( + col, + ( + cudf.core.column.ListColumn, + cudf.core.column.StructColumn, + ), + ) ): data[key] = col.astype(orig_dtype) else: From 96fa8656d51c9f016d5f7c319747abaff890c064 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 23 Aug 2023 13:17:07 -0700 Subject: [PATCH 4/4] fix --- python/cudf/cudf/core/groupby/groupby.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 163ba4fc4a3..3efe2e34f84 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -576,6 +576,7 @@ def agg(self, func): ( cudf.core.column.ListColumn, cudf.core.column.StructColumn, + cudf.core.column.DecimalBaseColumn, ), ) ):