Skip to content

Commit

Permalink
Cast count aggs to correct dtype in translation (#16192)
Browse files Browse the repository at this point in the history
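The default dtypes Polars uses for some aggregations, particularly count, don't match the dtypes libcudf produces (e.g. Polars returns UInt32 where libcudf returns Int32), so insert casts to the Polars dtype during translation.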

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #16192
  • Loading branch information
wence- authored Jul 4, 2024
1 parent 769e94f commit aa4033c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
17 changes: 13 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,11 @@ def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.E
# Push casts into literals so we can handle Cast(Literal(Null))
if isinstance(inner, expr.Literal):
return expr.Literal(dtype, inner.value.cast(plc.interop.to_arrow(dtype)))
else:
return expr.Cast(dtype, inner)
elif isinstance(inner, expr.Cast):
# Translation of Len/Count-agg put in a cast, remove double
# casts if we have one.
(inner,) = inner.children
return expr.Cast(dtype, inner)


@_translate_expr.register
Expand All @@ -443,12 +446,15 @@ def _(node: pl_expr.Column, visitor: NodeTraverser, dtype: plc.DataType) -> expr

@_translate_expr.register
def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
return expr.Agg(
value = expr.Agg(
dtype,
node.name,
node.options,
*(translate_expr(visitor, n=n) for n in node.arguments),
)
if value.name == "count" and value.dtype.id() != plc.TypeId.INT32:
return expr.Cast(value.dtype, value)
return value


@_translate_expr.register
Expand All @@ -475,7 +481,10 @@ def _(

@_translate_expr.register
def _(node: pl_expr.Len, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
return expr.Len(dtype)
value = expr.Len(dtype)
if dtype.id() != plc.TypeId.INT32:
return expr.Cast(dtype, value)
return value # pragma: no cover; never reached since polars len has uint32 dtype


def translate_expr(visitor: NodeTraverser, *, n: int) -> expr.Expr:
Expand Down
5 changes: 1 addition & 4 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ def test_groupby(df: pl.LazyFrame, maintain_order, keys, exprs):
def test_groupby_len(df, keys):
q = df.group_by(*keys).agg(pl.len())

# TODO: polars returns UInt32, libcudf returns Int32
with pytest.raises(AssertionError):
assert_gpu_result_equal(q, check_row_order=False)
assert_gpu_result_equal(q, check_dtypes=False, check_row_order=False)
assert_gpu_result_equal(q, check_row_order=False)


@pytest.mark.parametrize(
Expand Down

0 comments on commit aa4033c

Please sign in to comment.