Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix order-preservation in cudf-polars groupby #16907

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,39 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
]
broadcasted = broadcast(*result_keys, *results)
result_keys = broadcasted[: len(result_keys)]
results = broadcasted[len(result_keys) :]
# Handle order preservation of groups
# like cudf classic does
# https://github.com/rapidsai/cudf/blob/5780c4d8fb5afac2e04988a2ff5531f94c22d3a3/python/cudf/cudf/core/groupby/groupby.py#L723-L743
if self.maintain_order and not sorted:
left = plc.stream_compaction.stable_distinct(
# The order we want
want = plc.stream_compaction.stable_distinct(
plc.Table([k.obj for k in keys]),
list(range(group_keys.num_columns())),
plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
plc.types.NullEquality.EQUAL,
plc.types.NanEquality.ALL_EQUAL,
)
right = plc.Table([key.obj for key in result_keys])
_, indices = plc.join.left_join(left, right, plc.types.NullEquality.EQUAL)
# The order we have
have = plc.Table([key.obj for key in broadcasted[: len(keys)]])

# We know an inner join is OK because by construction
# want and have are permutations of each other.
left_order, right_order = plc.join.inner_join(
want, have, plc.types.NullEquality.EQUAL
)
# Now left_order is an arbitrary permutation of the ordering we
# want, and right_order is a matching permutation of the ordering
# we have. To get to the original ordering, we need
# left_order == iota(nrows), with right_order permuted
# appropriately. This can be obtained by sorting
# right_order by left_order.
(right_order,) = plc.sorting.sort_by_key(
plc.Table([right_order]),
plc.Table([left_order]),
[plc.types.Order.ASCENDING],
[plc.types.NullOrder.AFTER],
).columns()
ordered_table = plc.copying.gather(
plc.Table([col.obj for col in broadcasted]),
indices,
right_order,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
)
broadcasted = [
Expand Down
22 changes: 22 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -191,3 +192,24 @@ def test_groupby_literal_in_agg(df, key, expr):
def test_groupby_unary_non_pointwise_raises(df, expr):
    """A unary non-pointwise expression inside a groupby agg is unsupported.

    Translation of such a query to the cudf-polars IR should raise
    ``NotImplementedError`` rather than silently producing wrong results.
    """
    query = df.group_by("key1").agg(expr)
    assert_ir_translation_raises(query, NotImplementedError)


@pytest.mark.parametrize("nrows", [30, 300, 300_000])
@pytest.mark.parametrize("nkeys", [1, 2, 4])
def test_groupby_maintain_order_random(nrows, nkeys, with_nulls):
    """Groupby with ``maintain_order=True`` must keep first-seen group order.

    Builds a random frame with ``nkeys`` integer key columns and one value
    column, optionally nulling out a sentinel key value, and checks that the
    GPU result matches polars' CPU result (which defines the expected
    first-occurrence ordering of groups).

    Parameters
    ----------
    nrows
        Number of rows in the generated frame.
    nkeys
        Number of grouping key columns.
    with_nulls
        Fixture flag; when true, key value ``1`` is replaced with null to
        exercise grouping on null keys.
    """
    # Seed from the parameters so each parametrization is deterministic and
    # failures are reproducible across runs.
    rng = np.random.default_rng(seed=nrows * 31 + nkeys)
    key_names = [f"key{key}" for key in range(nkeys)]
    key_values = [rng.integers(100, size=nrows) for _ in key_names]
    value = rng.integers(-100, 100, size=nrows)
    df = pl.DataFrame(dict(zip(key_names, key_values, strict=True), value=value))
    if with_nulls:
        # Replace the sentinel key value 1 with null in every key column so
        # null-keyed groups participate in the ordering check.
        df = df.with_columns(
            *(
                pl.when(pl.col(name) == 1)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
                for name in key_names
            )
        )
    q = df.lazy().group_by(key_names, maintain_order=True).agg(pl.col("value").sum())
    assert_gpu_result_equal(q)
Loading