diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index 8cd56c8ee3a..1c61075be22 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -603,24 +603,39 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
             req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
         ]
         broadcasted = broadcast(*result_keys, *results)
-        result_keys = broadcasted[: len(result_keys)]
-        results = broadcasted[len(result_keys) :]
         # Handle order preservation of groups
-        # like cudf classic does
-        # https://github.com/rapidsai/cudf/blob/5780c4d8fb5afac2e04988a2ff5531f94c22d3a3/python/cudf/cudf/core/groupby/groupby.py#L723-L743
         if self.maintain_order and not sorted:
-            left = plc.stream_compaction.stable_distinct(
+            # The order we want
+            want = plc.stream_compaction.stable_distinct(
                 plc.Table([k.obj for k in keys]),
                 list(range(group_keys.num_columns())),
                 plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
                 plc.types.NullEquality.EQUAL,
                 plc.types.NanEquality.ALL_EQUAL,
             )
-            right = plc.Table([key.obj for key in result_keys])
-            _, indices = plc.join.left_join(left, right, plc.types.NullEquality.EQUAL)
+            # The order we have
+            have = plc.Table([key.obj for key in broadcasted[: len(keys)]])
+
+            # We know an inner join is OK because by construction
+            # want and have are permutations of each other.
+            left_order, right_order = plc.join.inner_join(
+                want, have, plc.types.NullEquality.EQUAL
+            )
+            # Now left_order is an arbitrary permutation of the ordering we
+            # want, and right_order is a matching permutation of the ordering
+            # we have. To get to the original ordering, we need
+            # left_order == iota(nrows), with right_order permuted
+            # appropriately. This can be obtained by sorting
+            # right_order by left_order.
+            (right_order,) = plc.sorting.sort_by_key(
+                plc.Table([right_order]),
+                plc.Table([left_order]),
+                [plc.types.Order.ASCENDING],
+                [plc.types.NullOrder.AFTER],
+            ).columns()
             ordered_table = plc.copying.gather(
                 plc.Table([col.obj for col in broadcasted]),
-                indices,
+                right_order,
                 plc.copying.OutOfBoundsPolicy.DONT_CHECK,
             )
             broadcasted = [
diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py
index 74bf8b9e4e2..1e8246496cd 100644
--- a/python/cudf_polars/tests/test_groupby.py
+++ b/python/cudf_polars/tests/test_groupby.py
@@ -4,6 +4,7 @@
 
 import itertools
 
+import numpy as np
 import pytest
 
 import polars as pl
@@ -191,3 +192,24 @@ def test_groupby_literal_in_agg(df, key, expr):
 def test_groupby_unary_non_pointwise_raises(df, expr):
     q = df.group_by("key1").agg(expr)
     assert_ir_translation_raises(q, NotImplementedError)
+
+
+@pytest.mark.parametrize("nrows", [30, 300, 300_000])
+@pytest.mark.parametrize("nkeys", [1, 2, 4])
+def test_groupby_maintain_order_random(nrows, nkeys, with_nulls):
+    key_names = [f"key{key}" for key in range(nkeys)]
+    key_values = [np.random.randint(100, size=nrows) for _ in key_names]
+    value = np.random.randint(-100, 100, size=nrows)
+    df = pl.DataFrame(dict(zip(key_names, key_values, strict=True), value=value))
+    if with_nulls:
+        df = df.with_columns(
+            *(
+                pl.when(pl.col(name) == 1)
+                .then(None)
+                .otherwise(pl.col(name))
+                .alias(name)
+                for name in key_names
+            )
+        )
+    q = df.lazy().group_by(key_names, maintain_order=True).agg(pl.col("value").sum())
+    assert_gpu_result_equal(q)
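
The reordering trick in the ir.py hunk (inner-join the key order we want against the key order we have, then sort the right-side match indices by the left-side ones to build a gather map) can be illustrated outside of pylibcudf. Below is a minimal NumPy sketch of the same idea; it is illustrative only and is not the cudf-polars code: the join is emulated with a boolean comparison matrix instead of a hash join, and the keys are plain integers rather than table rows.

# A minimal NumPy sketch of the order-restoration trick used above.
import numpy as np

rng = np.random.default_rng(42)

want = np.array([3, 1, 4, 0, 2])  # unique group keys, in first-appearance order
have = rng.permutation(want)      # same keys, in whatever order the groupby produced

# Emulate the inner join: record (index into `want`, index into `have`)
# for every pair of equal keys.
left_order, right_order = np.nonzero(want[:, None] == have[None, :])

# A hash join returns its matches in arbitrary order; emulate that by shuffling.
shuffle = rng.permutation(len(left_order))
left_order, right_order = left_order[shuffle], right_order[shuffle]

# Sort the `have` indices by the `want` indices so that position j of the
# gather map points at the row of `have` that holds want[j].
gather_map = right_order[np.argsort(left_order, kind="stable")]

assert (have[gather_map] == want).all()

Gathering the grouped result with this map (as plc.copying.gather does in the diff) then restores the first-appearance order of the keys.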