Skip to content

Commit

Permalink
WIP: making pylibcudf aggregation requests directly
Browse files (browse the repository at this point in the history)
  • Loading branch information
wence- committed Apr 16, 2024
1 parent c390fb9 commit b10c74d
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions python/cudf_polars/cudf_polars/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from polars.polars import expr_nodes

from cudf_polars.dataframe import DataFrame
from cudf_polars.utils import sort_order, to_cudf_dtype, to_pylibcudf_dtype
from cudf_polars.utils import (
placeholder_column,
sort_order,
to_cudf_dtype,
to_pylibcudf_dtype,
)

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -504,7 +509,9 @@ def agg_depth(agg, visitor: ExprVisitor) -> int:
# right now, we must run agg_depth first.
def collect_agg(
node: int, context: DataFrame, depth: int, visitor: ExprVisitor
) -> tuple[list[ColumnType | None], list[tuple[str, int]], str]:
) -> tuple[
list[ColumnType | None], list[tuple[plc.aggregation.Aggregation, int]], str
]:
"""
Collect the aggregation requirements of a single aggregation request.
Expand All @@ -522,12 +529,27 @@ def collect_agg(
assert depth <= 1
agg = visitor.node(node)
if isinstance(agg, expr_nodes.Column):
return [context[agg.name]], [("collect_list", node)], agg.name
return (
[context[agg.name]],
[(plc.aggregation.collect_list(), node)],
agg.name,
)
elif isinstance(agg, expr_nodes.Alias):
col, req, _ = collect_agg(agg.expr, context, depth, visitor)
return col, req, agg.name
elif isinstance(agg, expr_nodes.Len):
return [None], [("count_all", node)], "len"
return (
[placeholder_column(context.num_rows())],
[
(
plc.aggregation.count(
null_handling=plc.types.NullPolicy.INCLUDE
),
node,
)
],
"len",
)
elif isinstance(agg, expr_nodes.Agg):
request = agg.name
column, _, name = collect_agg(
Expand All @@ -543,13 +565,19 @@ def collect_agg(
dtype=libcudf.types.size_type_dtype,
).to_pylibcudf(mode="read")
]
request = "collect_list"
request = plc.aggregation.collect_list()
elif request == "implode":
raise NotImplementedError("implode in groupby not implemented")
elif request == "count" and agg.options:
# Include nulls
request = "count_all"
elif request == "count":
request = plc.aggregation.count(
null_handling=plc.types.NullPolicy.INCLUDE
if agg.options
else plc.types.NullPolicy.EXCLUDE
)
column = [None]
else:
# TODO: ensure all options are handled correctly
request = getattr(plc.aggregation, request)()
return column, [(request, node)], name
elif isinstance(agg, expr_nodes.BinaryExpr):
# TODO: no nested agg(binop(agg)) right now
Expand All @@ -560,15 +588,18 @@ def collect_agg(
return [*lcol, *rcol], [*lreq, *rreq], lname
else:
((name, column),) = evaluate_expr(agg, context, visitor).items()
return [column], [("collect_list", node)], name
return [column], [(plc.aggregation.collect_list(), node)], name
else:
raise NotImplementedError


def collect_aggs(
agg_exprs: list[int], context: DataFrame, visitor: ExprVisitor
) -> tuple[
list[ColumnType | None], list[list[str]], list[list[list[int]]], list[str]
list[ColumnType | None],
list[list[plc.aggregation.Aggregation]],
list[list[list[int]]],
list[str],
]:
"""
Collect all the unique aggregation requests.
Expand Down

0 comments on commit b10c74d

Please sign in to comment.