-
Notifications
You must be signed in to change notification settings - Fork 919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic multi-partition GroupBy
support to cuDF-Polars
#17503
Open
rjzamora
wants to merge
15
commits into
rapidsai:branch-25.02
Choose a base branch
from
rjzamora:cudf-polars-multi-groupby
base: branch-25.02
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
f0964a6
basic groupby-aggregation support
rjzamora 1329cf1
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora 11a03f8
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora a9fa486
Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars…
rjzamora b1224a0
remove GroupbyTree
rjzamora 385f03a
simplify lower
rjzamora 8956215
Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars…
rjzamora 70b29b2
Merge remote-tracking branch 'upstream/branch-25.02' into cudf-polars…
rjzamora 3f04eca
cleanup
rjzamora e090de5
no cover
rjzamora 24b88f2
tweak error message
rjzamora 161a53b
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora 69f6336
update copyright dates
rjzamora 22cebeb
add test coverage for single-partition
rjzamora 45ac8ec
Merge branch 'branch-25.02' into cudf-polars-multi-groupby
rjzamora File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Parallel GroupBy Logic.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
import pylibcudf as plc | ||
|
||
from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr | ||
from cudf_polars.dsl.ir import GroupBy, Select | ||
from cudf_polars.dsl.traversal import traversal | ||
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name | ||
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import MutableMapping | ||
|
||
from cudf_polars.dsl.expr import Expr | ||
from cudf_polars.dsl.ir import IR | ||
from cudf_polars.experimental.parallel import LowerIRTransformer | ||
|
||
|
||
# Supported multi-partition aggregations | ||
_GB_AGG_SUPPORTED = ("sum", "count", "mean") | ||
|
||
|
||
@lower_ir_node.register(GroupBy) | ||
def _( | ||
ir: GroupBy, rec: LowerIRTransformer | ||
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: | ||
# Extract child partitioning | ||
child, partition_info = rec(ir.children[0]) | ||
|
||
# Handle single-partition case | ||
if partition_info[child].count == 1: | ||
single_part_node = ir.reconstruct([child]) | ||
partition_info[single_part_node] = partition_info[child] | ||
return single_part_node, partition_info | ||
|
||
# Check group-by keys | ||
if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])): | ||
raise NotImplementedError( | ||
"GroupBy does not support multiple partitions " | ||
f"for these keys:\n{ir.keys}" | ||
) # pragma: no cover | ||
|
||
name_map: MutableMapping[str, Any] = {} | ||
agg_tree: Cast | Agg | None = None | ||
agg_requests_pwise = [] # Partition-wise requests | ||
agg_requests_tree = [] # Tree-node requests | ||
|
||
for ne in ir.agg_requests: | ||
name = ne.name | ||
agg: Expr = ne.value | ||
dtype = agg.dtype | ||
agg = agg.children[0] if isinstance(agg, Cast) else agg | ||
if isinstance(agg, Len): | ||
agg_requests_pwise.append(ne) | ||
agg_requests_tree.append( | ||
NamedExpr( | ||
name, | ||
Cast( | ||
dtype, | ||
Agg(dtype, "sum", None, Col(dtype, name)), | ||
), | ||
) | ||
) | ||
elif isinstance(agg, Agg): | ||
if agg.name not in _GB_AGG_SUPPORTED: | ||
raise NotImplementedError( | ||
"GroupBy does not support multiple partitions " | ||
f"for this expression:\n{agg}" | ||
) | ||
|
||
if agg.name in ("sum", "count"): | ||
agg_requests_pwise.append(ne) | ||
agg_requests_tree.append( | ||
NamedExpr( | ||
name, | ||
Cast( | ||
dtype, | ||
Agg(dtype, "sum", agg.options, Col(dtype, name)), | ||
), | ||
) | ||
) | ||
elif agg.name == "mean": | ||
name_map[name] = {agg.name: {}} | ||
for sub in ["sum", "count"]: | ||
# Partwise | ||
tmp_name = f"{name}__{sub}" | ||
name_map[name][agg.name][sub] = tmp_name | ||
agg_pwise = Agg(dtype, sub, agg.options, *agg.children) | ||
agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise)) | ||
# Tree | ||
agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name)) | ||
agg_requests_tree.append(NamedExpr(tmp_name, agg_tree)) | ||
else: | ||
# Unsupported expression | ||
raise NotImplementedError( | ||
"GroupBy does not support multiple partitions " | ||
f"for this expression:\n{agg}" | ||
) # pragma: no cover | ||
|
||
gb_pwise = GroupBy( | ||
ir.schema, | ||
ir.keys, | ||
agg_requests_pwise, | ||
ir.maintain_order, | ||
ir.options, | ||
child, | ||
) | ||
child_count = partition_info[child].count | ||
partition_info[gb_pwise] = PartitionInfo(count=child_count) | ||
|
||
gb_tree = GroupBy( | ||
ir.schema, | ||
ir.keys, | ||
agg_requests_tree, | ||
ir.maintain_order, | ||
ir.options, | ||
gb_pwise, | ||
) | ||
partition_info[gb_tree] = PartitionInfo(count=1) | ||
|
||
schema = ir.schema | ||
output_exprs = [] | ||
for name, dtype in schema.items(): | ||
agg_mapping = name_map.get(name, None) | ||
if agg_mapping is None: | ||
output_exprs.append(NamedExpr(name, Col(dtype, name))) | ||
elif "mean" in agg_mapping: | ||
mean_cols = agg_mapping["mean"] | ||
output_exprs.append( | ||
NamedExpr( | ||
name, | ||
BinOp( | ||
dtype, | ||
plc.binaryop.BinaryOperator.DIV, | ||
Col(dtype, mean_cols["sum"]), | ||
Col(dtype, mean_cols["count"]), | ||
), | ||
) | ||
) | ||
should_broadcast: bool = False | ||
new_node = Select( | ||
schema, | ||
output_exprs, | ||
should_broadcast, | ||
gb_tree, | ||
) | ||
partition_info[new_node] = PartitionInfo(count=1) | ||
return new_node, partition_info | ||
|
||
|
||
def _tree_node(do_evaluate, batch, *args): | ||
return do_evaluate(*args, _concat(batch)) | ||
rjzamora marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@generate_ir_tasks.register(GroupBy) | ||
def _( | ||
ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo] | ||
) -> MutableMapping[Any, Any]: | ||
(child,) = ir.children | ||
child_count = partition_info[child].count | ||
child_name = get_key_name(child) | ||
output_count = partition_info[ir].count | ||
|
||
if output_count == child_count: | ||
return { | ||
key: ( | ||
ir.do_evaluate, | ||
*ir._non_child_args, | ||
(child_name, i), | ||
) | ||
for i, key in enumerate(partition_info[ir].keys(ir)) | ||
} | ||
elif output_count != 1: # pragma: no cover | ||
raise ValueError(f"Expected single partition, got {output_count}") | ||
|
||
# Simple N-ary tree reduction | ||
j = 0 | ||
graph: MutableMapping[Any, Any] = {} | ||
n_ary = 32 # TODO: Make this configurable | ||
name = get_key_name(ir) | ||
keys: list[Any] = [(child_name, i) for i in range(child_count)] | ||
while len(keys) > n_ary: | ||
new_keys: list[Any] = [] | ||
for i, k in enumerate(range(0, len(keys), n_ary)): | ||
batch = keys[k : k + n_ary] | ||
graph[(name, j, i)] = ( | ||
_tree_node, | ||
ir.do_evaluate, | ||
batch, | ||
*ir._non_child_args, | ||
) | ||
new_keys.append((name, j, i)) | ||
j += 1 | ||
keys = new_keys | ||
graph[(name, 0)] = ( | ||
_tree_node, | ||
ir.do_evaluate, | ||
keys, | ||
*ir._non_child_args, | ||
) | ||
return graph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
import polars as pl | ||
|
||
from cudf_polars.testing.asserts import assert_gpu_result_equal | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def engine(): | ||
return pl.GPUEngine( | ||
raise_on_fail=True, | ||
executor="dask-experimental", | ||
executor_options={"max_rows_per_partition": 4}, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def df(): | ||
return pl.LazyFrame( | ||
{ | ||
"x": range(150), | ||
"y": ["cat", "dog", "fish"] * 50, | ||
"z": [1.0, 2.0, 3.0, 4.0, 5.0] * 30, | ||
} | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("op", ["sum", "mean", "len"]) | ||
@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) | ||
def test_groupby(df, engine, op, keys): | ||
q = getattr(df.group_by(*keys), op)() | ||
assert_gpu_result_equal(q, engine=engine, check_row_order=False) | ||
|
||
|
||
@pytest.mark.parametrize("op", ["sum", "mean", "len"]) | ||
@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) | ||
def test_groupby_single_partitions(df, op, keys): | ||
q = getattr(df.group_by(*keys), op)() | ||
assert_gpu_result_equal( | ||
q, | ||
engine=pl.GPUEngine( | ||
raise_on_fail=True, | ||
executor="dask-experimental", | ||
executor_options={"max_rows_per_partition": 1e9}, | ||
), | ||
check_row_order=False, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("op", ["sum", "mean", "len", "count"]) | ||
@pytest.mark.parametrize("keys", [("y",), ("y", "z")]) | ||
def test_groupby_agg(df, engine, op, keys): | ||
q = df.group_by(*keys).agg(getattr(pl.col("x"), op)()) | ||
assert_gpu_result_equal(q, engine=engine, check_row_order=False) | ||
|
||
|
||
def test_groupby_raises(df, engine): | ||
q = df.group_by("y").median() | ||
with pytest.raises( | ||
pl.exceptions.ComputeError, | ||
match="NotImplementedError", | ||
): | ||
assert_gpu_result_equal(q, engine=engine, check_row_order=False) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to think about this (and possibly reorganise what we're doing in the single-partition case) to make this easier to handle.
For example, I think it is going to do the wrong thing for
.agg(a.max() + b.min())
I think what you're trying to do here is turn a
GroupBy(df, keys, aggs)
intoReduce(LocalGroupBy(df, keys, agg_exprs), keys, transformed_aggs)
And what does this look like, I think once we've determined the "leaf" aggregations we're performing (e.g.
col.max()
) then we must concat and combine to get the full leaf aggregations, followed by evaluation of the column expressions that produce the final result.So suppose we have determined what the leaf aggs are, and then what the post-aggregation expressions are, for a single-partition this is effectively
Select(GroupBy(df, keys, leaf_aggs), keys, post_agg_exprs)
wherepost_agg_exprs
are all guaranteed elementwise (for now).thought: Would it be easier for you here if the
GroupBy
IR nodes really only held aggregation expressions that are "leaf" aggregations (with the post-processing done in aSelect
)?I think it would, because then the transform becomes something like:
Where
groupbycombine
emits the tree-reduction tasks with the post aggregations.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure the answer is "yes" :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick follow-up: I totally agree that we probably want to revise the upstream
GroupBy
design to make the decomposition here a bit simpler. With that said, I don't think we are doing anything "wrong" here. Rather, the code would just need to become unnecessarily messy if we wanted to do much more than "simple" mean/count/min/max aggregations.We won't do the "wrong" thing here - We will just raise an error. E.g.: