From 8ff27ed5bcaf8fc5fc8d1f546dee30c59861c320 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 19 Jul 2024 15:15:20 +0100 Subject: [PATCH] Support Literals in groupby-agg (#16218) To do this, we just need to collect the appropriate aggregation information, and broadcast literals to the correct size. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/16218 --- python/cudf_polars/cudf_polars/dsl/expr.py | 15 +++++++++++++++ python/cudf_polars/cudf_polars/dsl/ir.py | 4 ++-- python/cudf_polars/tests/test_groupby.py | 17 +++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index f37cb3f475c..a034d55120a 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -370,6 +370,10 @@ def do_evaluate( # datatype of pyarrow scalar is correct by construction. return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1)) + def collect_agg(self, *, depth: int) -> AggInfo: + """Collect information about aggregations in groupbys.""" + return AggInfo([]) + class LiteralColumn(Expr): __slots__ = ("value",) @@ -382,6 +386,13 @@ def __init__(self, dtype: plc.DataType, value: pl.Series) -> None: data = value.to_arrow() self.value = data.cast(dtypes.downcast_arrow_lists(data.type)) + def get_hash(self) -> int: + """Compute a hash of the column.""" + # This is stricter than necessary, but we only need this hash + # for identity in groupby replacements so it's OK. And this + # way we avoid doing potentially expensive compute. + return hash((type(self), self.dtype, id(self.value))) + def do_evaluate( self, df: DataFrame, @@ -393,6 +404,10 @@ def do_evaluate( # datatype of pyarrow array is correct by construction. return Column(plc.interop.from_arrow(self.value)) + def collect_agg(self, *, depth: int) -> AggInfo: + """Collect information about aggregations in groupbys.""" + return AggInfo([]) + class Col(Expr): __slots__ = ("name",) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index cce0c4a3d94..01834ab75a5 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -514,7 +514,7 @@ def check_agg(agg: expr.Expr) -> int: return max(GroupBy.check_agg(child) for child in agg.children) elif isinstance(agg, expr.Agg): return 1 + max(GroupBy.check_agg(child) for child in agg.children) - elif isinstance(agg, (expr.Len, expr.Col, expr.Literal)): + elif isinstance(agg, (expr.Len, expr.Col, expr.Literal, expr.LiteralColumn)): return 0 else: raise NotImplementedError(f"No handler for {agg=}") @@ -574,7 +574,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: results = [ req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests ] - return DataFrame([*result_keys, *results]).slice(self.options.slice) + return DataFrame(broadcast(*result_keys, *results)).slice(self.options.slice) @dataclasses.dataclass diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index b07d8e38217..b650fee5079 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -155,3 +155,20 @@ def test_groupby_nan_minmax_raises(op): q = df.group_by("key").agg(op(pl.col("value"))) assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize("key", [1, pl.col("key1")]) +@pytest.mark.parametrize( + "expr", + [ + pl.lit(1).alias("value"), + pl.lit([[4, 5, 6]]).alias("value"), + pl.col("float") * (1 - pl.col("int")), + [pl.lit(2).alias("value"), pl.col("float") * 2], + ], +) +def test_groupby_literal_in_agg(df, key, expr): + # check_row_order=False doesn't work for list aggregations + # so just sort by the group key + q = df.group_by(key).agg(expr).sort(key, maintain_order=True) + assert_gpu_result_equal(q)