Skip to content

Commit

Permalink
Support Literals in groupby-agg (#16218)
Browse files Browse the repository at this point in the history
To do this, we just need to collect the appropriate aggregation information, and broadcast literals to the correct size.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #16218
  • Loading branch information
wence- authored Jul 19, 2024
1 parent debbef0 commit 8ff27ed
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
15 changes: 15 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ def do_evaluate(
# datatype of pyarrow scalar is correct by construction.
return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1))

def collect_agg(self, *, depth: int) -> AggInfo:
    """
    Collect information about aggregations in groupbys.

    A scalar literal contributes no aggregation requests (it is
    handled by broadcasting after the grouped aggregations run),
    so an empty AggInfo is returned.  ``depth`` is accepted for
    interface parity with other expression nodes but not used here.
    """
    return AggInfo([])


class LiteralColumn(Expr):
__slots__ = ("value",)
Expand All @@ -382,6 +386,13 @@ def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))

def get_hash(self) -> int:
    """Compute a hash of the column."""
    # Key on object identity (``id``) instead of array contents: this
    # hash is only needed for identity matching in groupby
    # replacements, so being stricter than necessary is acceptable and
    # spares us potentially expensive compute over the data.
    key = (type(self), self.dtype, id(self.value))
    return hash(key)

def do_evaluate(
self,
df: DataFrame,
Expand All @@ -393,6 +404,10 @@ def do_evaluate(
# datatype of pyarrow array is correct by construction.
return Column(plc.interop.from_arrow(self.value))

def collect_agg(self, *, depth: int) -> AggInfo:
    """
    Collect information about aggregations in groupbys.

    A literal column needs no aggregation requests of its own, hence
    the empty AggInfo.  ``depth`` is kept for interface parity with
    other expression nodes but is not consulted here.
    """
    return AggInfo([])


class Col(Expr):
__slots__ = ("name",)
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def check_agg(agg: expr.Expr) -> int:
return max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, expr.Agg):
return 1 + max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, (expr.Len, expr.Col, expr.Literal)):
elif isinstance(agg, (expr.Len, expr.Col, expr.Literal, expr.LiteralColumn)):
return 0
else:
raise NotImplementedError(f"No handler for {agg=}")
Expand Down Expand Up @@ -574,7 +574,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
results = [
req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
]
return DataFrame([*result_keys, *results]).slice(self.options.slice)
return DataFrame(broadcast(*result_keys, *results)).slice(self.options.slice)


@dataclasses.dataclass
Expand Down
17 changes: 17 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,20 @@ def test_groupby_nan_minmax_raises(op):
q = df.group_by("key").agg(op(pl.col("value")))

assert_ir_translation_raises(q, NotImplementedError)


@pytest.mark.parametrize("key", [1, pl.col("key1")])
@pytest.mark.parametrize(
    "expr",
    [
        pl.lit(1).alias("value"),
        pl.lit([[4, 5, 6]]).alias("value"),
        pl.col("float") * (1 - pl.col("int")),
        [pl.lit(2).alias("value"), pl.col("float") * 2],
    ],
)
def test_groupby_literal_in_agg(df, key, expr):
    """Scalar and list literals are supported inside group_by aggregations."""
    # check_row_order=False doesn't work for list aggregations
    # so just sort by the group key
    q = df.group_by(key).agg(expr).sort(key, maintain_order=True)
    assert_gpu_result_equal(q)

0 comments on commit 8ff27ed

Please sign in to comment.