Skip to content

Commit

Permalink
Fix order-preservation in cudf-polars groupby (rapidsai#16907)
Browse files Browse the repository at this point in the history
When we are requested to maintain order in groupby aggregations, we must post-process the result by computing a permutation between the wanted order (of the input keys) and the order returned by the groupby aggregation. To do this, we can perform a join between the two unique key tables. Previously, we assumed that the gather map returned in this join for the left (wanted-order) table was the identity. However, this is not guaranteed; therefore, in addition to computing the match between the wanted key order and the key order we have, we must also apply the permutation between the left gather map order and the identity.

- Closes rapidsai#16893

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - https://github.com/brandon-b-miller

URL: rapidsai#16907
  • Loading branch information
wence- authored Sep 30, 2024
1 parent e2bcbb8 commit 2b6408b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
31 changes: 23 additions & 8 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,39 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
]
broadcasted = broadcast(*result_keys, *results)
result_keys = broadcasted[: len(result_keys)]
results = broadcasted[len(result_keys) :]
# Handle order preservation of groups
# like cudf classic does
# https://github.com/rapidsai/cudf/blob/5780c4d8fb5afac2e04988a2ff5531f94c22d3a3/python/cudf/cudf/core/groupby/groupby.py#L723-L743
if self.maintain_order and not sorted:
left = plc.stream_compaction.stable_distinct(
# The order we want
want = plc.stream_compaction.stable_distinct(
plc.Table([k.obj for k in keys]),
list(range(group_keys.num_columns())),
plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
plc.types.NullEquality.EQUAL,
plc.types.NanEquality.ALL_EQUAL,
)
right = plc.Table([key.obj for key in result_keys])
_, indices = plc.join.left_join(left, right, plc.types.NullEquality.EQUAL)
# The order we have
have = plc.Table([key.obj for key in broadcasted[: len(keys)]])

# We know an inner join is OK because by construction
# want and have are permutations of each other.
left_order, right_order = plc.join.inner_join(
want, have, plc.types.NullEquality.EQUAL
)
# Now left_order is an arbitrary permutation of the ordering we
# want, and right_order is a matching permutation of the ordering
# we have. To get to the original ordering, we need
# left_order == iota(nrows), with right_order permuted
# appropriately. This can be obtained by sorting
# right_order by left_order.
(right_order,) = plc.sorting.sort_by_key(
plc.Table([right_order]),
plc.Table([left_order]),
[plc.types.Order.ASCENDING],
[plc.types.NullOrder.AFTER],
).columns()
ordered_table = plc.copying.gather(
plc.Table([col.obj for col in broadcasted]),
indices,
right_order,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
)
broadcasted = [
Expand Down
22 changes: 22 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools

import numpy as np
import pytest

import polars as pl
Expand Down Expand Up @@ -191,3 +192,24 @@ def test_groupby_literal_in_agg(df, key, expr):
def test_groupby_unary_non_pointwise_raises(df, expr):
q = df.group_by("key1").agg(expr)
assert_ir_translation_raises(q, NotImplementedError)


@pytest.mark.parametrize("nrows", [30, 300, 300_000])
@pytest.mark.parametrize("nkeys", [1, 2, 4])
def test_groupby_maintain_order_random(nrows, nkeys, with_nulls):
    """Randomised regression test for group-order preservation.

    With ``maintain_order=True`` the groupby must return groups in
    first-occurrence order of the keys; see rapidsai/cudf#16893.

    Parameters
    ----------
    nrows
        Number of rows in the random frame.
    nkeys
        Number of grouping key columns.
    with_nulls
        Fixture flag; when true, one key value is replaced by nulls to
        exercise null-equality handling in the key match.
    """
    # Seed derived from the parametrisation so each case is distinct but
    # reproducible — an unseeded random test cannot be replayed on failure.
    rng = np.random.default_rng(seed=nrows * 31 + nkeys)
    key_names = [f"key{key}" for key in range(nkeys)]
    # Draw keys from only 100 distinct values so larger nrows guarantee
    # duplicate keys, which is what exercises the order-preservation path.
    key_values = [rng.integers(100, size=nrows) for _ in key_names]
    value = rng.integers(-100, 100, size=nrows)
    df = pl.DataFrame(dict(zip(key_names, key_values, strict=True), value=value))
    if with_nulls:
        # Turn the key value 1 into null in every key column.
        df = df.with_columns(
            *(
                pl.when(pl.col(name) == 1)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
                for name in key_names
            )
        )
    q = df.lazy().group_by(key_names, maintain_order=True).agg(pl.col("value").sum())
    assert_gpu_result_equal(q)

0 comments on commit 2b6408b

Please sign in to comment.