From aa4033c5fe0be9e3d235d5722f1030c60b04e34d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 4 Jul 2024 10:10:02 +0100 Subject: [PATCH] Cast count aggs to correct dtype in translation (#16192) Polars default dtypes for some aggregations, particularly count, don't match ours, so insert casts. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/16192 --- python/cudf_polars/cudf_polars/dsl/translate.py | 17 +++++++++++++---- python/cudf_polars/tests/test_groupby.py | 5 +---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index a2fdb3c3d79..0019b3aa98a 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -432,8 +432,11 @@ def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.E # Push casts into literals so we can handle Cast(Literal(Null)) if isinstance(inner, expr.Literal): return expr.Literal(dtype, inner.value.cast(plc.interop.to_arrow(dtype))) - else: - return expr.Cast(dtype, inner) + elif isinstance(inner, expr.Cast): + # Translation of Len/Count-agg put in a cast, remove double + # casts if we have one. + (inner,) = inner.children + return expr.Cast(dtype, inner) @_translate_expr.register @@ -443,12 +446,15 @@ def _(node: pl_expr.Column, visitor: NodeTraverser, dtype: plc.DataType) -> expr @_translate_expr.register def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: - return expr.Agg( + value = expr.Agg( dtype, node.name, node.options, *(translate_expr(visitor, n=n) for n in node.arguments), ) + if value.name == "count" and value.dtype.id() != plc.TypeId.INT32: + return expr.Cast(value.dtype, value) + return value @_translate_expr.register @@ -475,7 +481,10 @@ def _( @_translate_expr.register def _(node: pl_expr.Len, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr: - return expr.Len(dtype) + value = expr.Len(dtype) + if dtype.id() != plc.TypeId.INT32: + return expr.Cast(dtype, value) + return value # pragma: no cover; never reached since polars len has uint32 dtype def translate_expr(visitor: NodeTraverser, *, n: int) -> expr.Expr: diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index aefad59eb91..8a6732b7063 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -83,10 +83,7 @@ def test_groupby(df: pl.LazyFrame, maintain_order, keys, exprs): def test_groupby_len(df, keys): q = df.group_by(*keys).agg(pl.len()) - # TODO: polars returns UInt32, libcudf returns Int32 - with pytest.raises(AssertionError): - assert_gpu_result_equal(q, check_row_order=False) - assert_gpu_result_equal(q, check_dtypes=False, check_row_order=False) + assert_gpu_result_equal(q, check_row_order=False) @pytest.mark.parametrize(