From 11992465fc8adc9cbeeb09ecd7db5a3f97678667 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 25 Sep 2024 11:24:01 +0000 Subject: [PATCH] Fix order-preservation in cudf-polars groupby When we are requested to maintain order in groupby aggregations we must post-process the result by computing a permutation between the wanted order (of the input keys) and the order returned by the groupby aggregation. To do this, we can perform a join between the two unique key tables. Previously, we assumed that the gather map returned in this join for the left (wanted order) table was the identity. However, this is not guaranteed: in addition to computing the match between the wanted key order and the key order we have, we must also apply the permutation between the left gather map order and the identity. - Closes #16893 --- python/cudf_polars/cudf_polars/dsl/ir.py | 31 ++++++++++++++++++------ python/cudf_polars/tests/test_groupby.py | 22 +++++++++++++++++ 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 8cd56c8ee3a..1c61075be22 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -603,24 +603,39 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests ] broadcasted = broadcast(*result_keys, *results) - result_keys = broadcasted[: len(result_keys)] - results = broadcasted[len(result_keys) :] # Handle order preservation of groups - # like cudf classic does - # https://github.com/rapidsai/cudf/blob/5780c4d8fb5afac2e04988a2ff5531f94c22d3a3/python/cudf/cudf/core/groupby/groupby.py#L723-L743 if self.maintain_order and not sorted: - left = plc.stream_compaction.stable_distinct( + # The order we want + want = plc.stream_compaction.stable_distinct( plc.Table([k.obj for k in keys]), list(range(group_keys.num_columns())), 
plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST, plc.types.NullEquality.EQUAL, plc.types.NanEquality.ALL_EQUAL, ) - right = plc.Table([key.obj for key in result_keys]) - _, indices = plc.join.left_join(left, right, plc.types.NullEquality.EQUAL) + # The order we have + have = plc.Table([key.obj for key in broadcasted[: len(keys)]]) + + # We know an inner join is OK because by construction + # want and have are permutations of each other. + left_order, right_order = plc.join.inner_join( + want, have, plc.types.NullEquality.EQUAL + ) + # Now left_order is an arbitrary permutation of the ordering we + # want, and right_order is a matching permutation of the ordering + # we have. To get to the original ordering, we need + # left_order == iota(nrows), with right_order permuted + # appropriately. This can be obtained by sorting + # right_order by left_order. + (right_order,) = plc.sorting.sort_by_key( + plc.Table([right_order]), + plc.Table([left_order]), + [plc.types.Order.ASCENDING], + [plc.types.NullOrder.AFTER], + ).columns() ordered_table = plc.copying.gather( plc.Table([col.obj for col in broadcasted]), - indices, + right_order, plc.copying.OutOfBoundsPolicy.DONT_CHECK, ) broadcasted = [ diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index 74bf8b9e4e2..1e8246496cd 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -4,6 +4,7 @@ import itertools +import numpy as np import pytest import polars as pl @@ -191,3 +192,24 @@ def test_groupby_literal_in_agg(df, key, expr): def test_groupby_unary_non_pointwise_raises(df, expr): q = df.group_by("key1").agg(expr) assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize("nrows", [30, 300, 300_000]) +@pytest.mark.parametrize("nkeys", [1, 2, 4]) +def test_groupby_maintain_order_random(nrows, nkeys, with_nulls): + key_names = [f"key{key}" for key in range(nkeys)] + key_values = 
[np.random.randint(100, size=nrows) for _ in key_names] + value = np.random.randint(-100, 100, size=nrows) + df = pl.DataFrame(dict(zip(key_names, key_values, strict=True), value=value)) + if with_nulls: + df = df.with_columns( + *( + pl.when(pl.col(name) == 1) + .then(None) + .otherwise(pl.col(name)) + .alias(name) + for name in key_names + ) + ) + q = df.lazy().group_by(key_names, maintain_order=True).agg(pl.col("value").sum()) + assert_gpu_result_equal(q)