Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic multi-partition GroupBy support to cuDF-Polars #17503

Open
wants to merge 15 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel GroupBy Logic."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pylibcudf as plc

from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr
from cudf_polars.dsl.ir import GroupBy, Select
from cudf_polars.dsl.traversal import traversal
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.expr import Expr
from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.parallel import LowerIRTransformer


# Names of ``Agg`` aggregations accepted by the multi-partition GroupBy
# lowering below; an ``Agg`` with any other name is rejected there with
# ``NotImplementedError``. ``Len`` expressions are handled separately and
# are not governed by this tuple.
_GB_AGG_SUPPORTED = ("sum", "count", "mean")


# Lower a ``GroupBy`` IR node for (possibly) multi-partition execution.
#
# Parameters
#   ir  - the GroupBy node being lowered.
#   rec - the lowering transformer; calling it on a child returns the lowered
#         child together with the partition-info mapping.
#
# Returns
#   A ``(new_node, partition_info)`` tuple. With a single-partition child the
#   node is simply reconstructed on top of the lowered child. Otherwise the
#   result is a ``Select`` over a single-partition "tree" ``GroupBy`` over a
#   partition-wise ``GroupBy``.
#
# Raises
#   NotImplementedError - for non-pointwise keys, or for aggregation requests
#   other than ``Len`` / the ``Agg`` names in ``_GB_AGG_SUPPORTED``.
@lower_ir_node.register(GroupBy)
def _(
ir: GroupBy, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
# Lower the (only) child first; this also yields the partition_info mapping
# that is extended and returned by this function.
child, partition_info = rec(ir.children[0])

# Single-partition child: no decomposition is needed. Rebuild the node on the
# lowered child and give it the same PartitionInfo as the child.
if partition_info[child].count == 1:
single_part_node = ir.reconstruct([child])
partition_info[single_part_node] = partition_info[child]
return single_part_node, partition_info

# Every expression reached by traversing the key expressions must report
# ``is_pointwise``; otherwise the multi-partition path is refused.
if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.keys])):
raise NotImplementedError(
"GroupBy does not support multiple partitions "
f"for these keys:\n{ir.keys}"
) # pragma: no cover

# name_map: output column name -> {agg name: {sub-agg name: temp column}}.
# Only populated for "mean", which is decomposed into "sum" and "count".
name_map: MutableMapping[str, Any] = {}
# Scratch variable, reassigned for each tree-level Agg built for "mean".
agg_tree: Cast | Agg | None = None
agg_requests_pwise = [] # Partition-wise requests
agg_requests_tree = [] # Tree-node requests

# Translate each requested aggregation into a partition-wise request plus a
# matching tree-level request that combines the per-partition results.
for ne in ir.agg_requests:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to think about this (and possibly reorganise what we're doing in the single-partition case) to make this easier to handle.

For example, I think it is going to do the wrong thing for .agg(a.max() + b.min())

I think what you're trying to do here is turn a GroupBy(df, keys, aggs) into Reduce(LocalGroupBy(df, keys, agg_exprs), keys, transformed_aggs)

And what does this look like? I think that once we've determined the "leaf" aggregations we're performing (e.g. col.max()), we must concat and combine to get the full leaf aggregations, and then evaluate the column expressions that produce the final result.

So suppose we have determined what the leaf aggs are, and then what the post-aggregation expressions are. For a single partition this is effectively Select(GroupBy(df, keys, leaf_aggs), keys, post_agg_exprs), where post_agg_exprs are all guaranteed elementwise (for now).

thought: Would it be easier for you here if the GroupBy IR nodes really only held aggregation expressions that are "leaf" aggregations (with the post-processing done in a Select)?

I think it would, because then the transform becomes something like:

Select(
   GroupByCombine(GroupBy(df, keys, leaf_aggs), keys, post_aggs),
   keys, post_agg_exprs
)

Here, GroupByCombine emits the tree-reduction tasks with the post aggregations.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: Would it be easier for you here if the GroupBy IR nodes really only held aggregation expressions that are "leaf" aggregations (with the post-processing done in a Select)?

I'm pretty sure the answer is "yes" :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick follow-up: I totally agree that we probably want to revise the upstream GroupBy design to make the decomposition here a bit simpler. With that said, I don't think we are doing anything "wrong" here. Rather, the code would just need to become unnecessarily messy if we wanted to do much more than "simple" mean/count/min/max aggregations.

For example, I think it is going to do the wrong thing for .agg(a.max() + b.min())

We won't do the "wrong" thing here - We will just raise an error. E.g.:

polars.exceptions.ComputeError: NotImplementedError: GroupBy does not support multiple partitions for this expression:
BinOp(<pylibcudf.types.DataType object at 0x7f06ebcc63b0>, <binary_operator.ADD: 0>, Cast(<pylibcudf.types.DataType object at 0x7f06ebcc63b0>, Agg(<pylibcudf.types.DataType object at 0x7f06ebcc6370>, 'max', False, Col(<pylibcudf.types.DataType object at 0x7f06ebcc6370>, 'x'))), Agg(<pylibcudf.types.DataType object at 0x7f06ebcc63b0>, 'max', False, Col(<pylibcudf.types.DataType object at 0x7f06ebcc63b0>, 'z')))

name = ne.name
agg: Expr = ne.value
# Keep the dtype of the outermost expression (the Cast, if present) ...
dtype = agg.dtype
# ... and look through a single top-level Cast to the wrapped expression.
agg = agg.children[0] if isinstance(agg, Cast) else agg
if isinstance(agg, Len):
# Len: run the original request on each partition, then combine by summing
# the per-partition column of the same name at the tree level.
agg_requests_pwise.append(ne)
agg_requests_tree.append(
NamedExpr(
name,
Cast(
dtype,
Agg(dtype, "sum", None, Col(dtype, name)),
),
)
)
elif isinstance(agg, Agg):
# Only the aggregation names listed in _GB_AGG_SUPPORTED are decomposed.
if agg.name not in _GB_AGG_SUPPORTED:
raise NotImplementedError(
"GroupBy does not support multiple partitions "
f"for this expression:\n{agg}"
)

if agg.name in ("sum", "count"):
# sum/count: original request per partition, then a "sum" of the
# per-partition column (same name) at the tree level.
agg_requests_pwise.append(ne)
agg_requests_tree.append(
NamedExpr(
name,
Cast(
dtype,
Agg(dtype, "sum", agg.options, Col(dtype, name)),
),
)
)
elif agg.name == "mean":
# mean: decomposed into partition-wise "sum" and "count" stored under the
# temporary names "<name>__sum" / "<name>__count"; each is summed at the
# tree level and the quotient is formed in the final Select below.
name_map[name] = {agg.name: {}}
for sub in ["sum", "count"]:
# Partwise
tmp_name = f"{name}__{sub}"
name_map[name][agg.name][sub] = tmp_name
# NOTE(review): both sub-aggregations reuse the mean's dtype and options,
# including for "count" — confirm this is intended.
agg_pwise = Agg(dtype, sub, agg.options, *agg.children)
agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise))
# Tree
agg_tree = Agg(dtype, "sum", agg.options, Col(dtype, tmp_name))
agg_requests_tree.append(NamedExpr(tmp_name, agg_tree))
else:
# Unsupported expression
raise NotImplementedError(
"GroupBy does not support multiple partitions "
f"for this expression:\n{agg}"
) # pragma: no cover

# Partition-wise GroupBy over the lowered child; it keeps the child's
# partition count.
# NOTE(review): this node is built with ir.schema even though, when a "mean"
# is requested, agg_requests_pwise carries the temporary "<name>__sum" /
# "<name>__count" names instead of "<name>" — confirm whether the schema
# should be derived from the keys and agg_requests_pwise instead.
gb_pwise = GroupBy(
ir.schema,
ir.keys,
agg_requests_pwise,
ir.maintain_order,
ir.options,
child,
)
child_count = partition_info[child].count
partition_info[gb_pwise] = PartitionInfo(count=child_count)

# Tree-level GroupBy over the partition-wise results; registered as producing
# a single partition.
# NOTE(review): same ir.schema question as for gb_pwise above; ir.keys is also
# reused unchanged on top of the partition-wise output — confirm.
gb_tree = GroupBy(
ir.schema,
ir.keys,
agg_requests_tree,
ir.maintain_order,
ir.options,
gb_pwise,
)
partition_info[gb_tree] = PartitionInfo(count=1)

# Final projection back onto the original output schema: columns without a
# name_map entry pass through unchanged, while a "mean" column is computed as
# its tree-level sum column divided by its tree-level count column.
schema = ir.schema
output_exprs = []
for name, dtype in schema.items():
agg_mapping = name_map.get(name, None)
if agg_mapping is None:
output_exprs.append(NamedExpr(name, Col(dtype, name)))
elif "mean" in agg_mapping:
mean_cols = agg_mapping["mean"]
output_exprs.append(
NamedExpr(
name,
BinOp(
dtype,
plc.binaryop.BinaryOperator.DIV,
Col(dtype, mean_cols["sum"]),
Col(dtype, mean_cols["count"]),
),
)
)
# The Select is created with should_broadcast=False and is registered as a
# single-partition node, like the tree GroupBy it wraps.
should_broadcast: bool = False
new_node = Select(
schema,
output_exprs,
should_broadcast,
gb_tree,
)
partition_info[new_node] = PartitionInfo(count=1)
return new_node, partition_info


def _tree_node(do_evaluate, batch, *args):
    """Run one node of the tree reduction.

    ``batch`` is handed to ``_concat`` and the result is passed to
    ``do_evaluate`` as the final positional argument, after ``args``.
    """
    combined = _concat(batch)
    return do_evaluate(*args, combined)
rjzamora marked this conversation as resolved.
Show resolved Hide resolved


@generate_ir_tasks.register(GroupBy)
def _(
    ir: GroupBy, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
    """Build the task graph for a ``GroupBy`` node.

    When the node has as many output partitions as its child, one task per
    output key is emitted, each evaluating the node on the child partition
    with the same index. Otherwise the node must have exactly one output
    partition, and the child partitions are combined with an N-ary tree of
    ``_tree_node`` tasks whose final task is keyed ``(name, 0)``.

    Raises
    ------
    ValueError
        If the output partition count is neither the child's count nor 1.
    """
    (child,) = ir.children
    n_in = partition_info[child].count
    src_name = get_key_name(child)
    n_out = partition_info[ir].count

    # Partition-wise case: output partition i depends only on child partition i.
    if n_out == n_in:
        tasks: MutableMapping[Any, Any] = {}
        for idx, out_key in enumerate(partition_info[ir].keys(ir)):
            tasks[out_key] = (
                ir.do_evaluate,
                *ir._non_child_args,
                (src_name, idx),
            )
        return tasks

    if n_out != 1:  # pragma: no cover
        raise ValueError(f"Expected single partition, got {n_out}")

    # Simple N-ary tree reduction
    fan_in = 32  # TODO: Make this configurable
    out_name = get_key_name(ir)
    graph: MutableMapping[Any, Any] = {}
    pending: list[Any] = [(src_name, idx) for idx in range(n_in)]
    level = 0
    # Collapse groups of up to ``fan_in`` keys per level until the remaining
    # keys fit into a single final task.
    while len(pending) > fan_in:
        reduced: list[Any] = []
        for slot, start in enumerate(range(0, len(pending), fan_in)):
            node_key = (out_name, level, slot)
            graph[node_key] = (
                _tree_node,
                ir.do_evaluate,
                pending[start : start + fan_in],
                *ir._non_child_args,
            )
            reduced.append(node_key)
        level += 1
        pending = reduced

    # Final task producing the single output partition.
    graph[(out_name, 0)] = (
        _tree_node,
        ir.do_evaluate,
        pending,
        *ir._non_child_args,
    )
    return graph
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Multi-partition Dask execution."""

Expand All @@ -9,6 +9,7 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

import cudf_polars.experimental.groupby
import cudf_polars.experimental.io
import cudf_polars.experimental.select # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
Expand Down
68 changes: 68 additions & 0 deletions python/cudf_polars/tests/experimental/test_groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(scope="module")
def engine():
    """Module-scoped GPU engine using the ``dask-experimental`` executor.

    ``max_rows_per_partition`` is set to 4, presumably so that the test
    frame is split across several partitions.
    """
    executor_options = {"max_rows_per_partition": 4}
    return pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options=executor_options,
    )


@pytest.fixture(scope="module")
def df():
    """Module-scoped 150-row LazyFrame with int, string and float columns."""
    columns = {
        "x": range(150),
        "y": ["cat", "dog", "fish"] * 50,
        "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 30,
    }
    return pl.LazyFrame(columns)


@pytest.mark.parametrize("op", ["sum", "mean", "len"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby(df, engine, op, keys):
    """Check ``group_by(*keys).<op>()`` on the GPU engine against the reference result."""
    grouped = df.group_by(*keys)
    query = getattr(grouped, op)()
    assert_gpu_result_equal(query, engine=engine, check_row_order=False)


@pytest.mark.parametrize("op", ["sum", "mean", "len"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby_single_partitions(df, op, keys):
    """Same check as ``test_groupby`` but with a very large per-partition row limit."""
    grouped = df.group_by(*keys)
    query = getattr(grouped, op)()
    # A limit far above the frame's 150 rows; presumably yields one partition.
    large_partition_engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options={"max_rows_per_partition": 1e9},
    )
    assert_gpu_result_equal(
        query,
        engine=large_partition_engine,
        check_row_order=False,
    )


@pytest.mark.parametrize("op", ["sum", "mean", "len", "count"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby_agg(df, engine, op, keys):
    """Check ``group_by(*keys).agg(pl.col("x").<op>())`` on the GPU engine."""
    grouped = df.group_by(*keys)
    aggregation = getattr(pl.col("x"), op)()
    query = grouped.agg(aggregation)
    assert_gpu_result_equal(query, engine=engine, check_row_order=False)


def test_groupby_raises(df, engine):
    """A ``median`` group-by must surface a ``NotImplementedError`` as a ``ComputeError``."""
    query = df.group_by("y").median()
    expected_failure = pytest.raises(
        pl.exceptions.ComputeError,
        match="NotImplementedError",
    )
    with expected_failure:
        assert_gpu_result_equal(query, engine=engine, check_row_order=False)
Loading