Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement translation for some unary functions and a single datetime extraction #16173

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"Col",
"BooleanFunction",
"StringFunction",
"TemporalFunction",
"Sort",
"SortBy",
"Gather",
Expand Down Expand Up @@ -779,6 +780,129 @@ def do_evaluate(
) # pragma: no cover; handled by init raising


class TemporalFunction(Expr):
    """Expression node wrapping a polars temporal (datetime) function.

    Only ``pl_expr.TemporalFunction.Year`` is accepted; every other
    temporal function is rejected when the node is constructed.
    """

    __slots__ = ("name", "options", "children")
    _non_child = ("dtype", "name", "options")
    children: tuple[Expr, ...]

    def __init__(
        self,
        dtype: plc.DataType,
        name: pl_expr.TemporalFunction,
        options: tuple[Any, ...],
        *children: Expr,
    ) -> None:
        """
        Build the node and validate that the function is supported.

        Parameters
        ----------
        dtype
            Data type passed through to the ``Expr`` base class.
        name
            The temporal function this node represents.
        options
            Function options, stored unchanged on the node.
        children
            Child expressions the function is applied to.

        Raises
        ------
        NotImplementedError
            If ``name`` is anything other than ``TemporalFunction.Year``.
        """
        super().__init__(dtype)
        self.options = options
        self.name = name
        self.children = children
        if self.name != pl_expr.TemporalFunction.Year:
            # The message names the right node type: this is a temporal
            # function, not a string function.
            raise NotImplementedError(f"Temporal function {self.name}")

    def do_evaluate(
        self,
        df: DataFrame,
        *,
        context: ExecutionContext = ExecutionContext.FRAME,
        mapping: Mapping[Expr, Column] | None = None,
    ) -> Column:
        """Evaluate this expression given a dataframe for context."""
        columns = [
            child.evaluate(df, context=context, mapping=mapping)
            for child in self.children
        ]
        if self.name == pl_expr.TemporalFunction.Year:
            # Year extraction takes exactly one input column; the unpacking
            # raises if a different number of children was supplied.
            (column,) = columns
            return Column(plc.datetime.extract_year(column.obj))
        raise NotImplementedError(
            f"TemporalFunction {self.name}"
        )  # pragma: no cover; init trips first


class UnaryFunction(Expr):
    """Expression node for a string-named function of a single child.

    The supported function names are ``"round"`` and ``"unique"``; any
    other name is rejected when the node is constructed.
    """

    __slots__ = ("name", "options", "children")
    _non_child = ("dtype", "name", "options")
    children: tuple[Expr, ...]

    def __init__(
        self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr
    ) -> None:
        """Store the function name, options and children, then validate the name."""
        super().__init__(dtype)
        self.name = name
        self.options = options
        self.children = children
        supported = self.name in ("round", "unique")
        if not supported:
            raise NotImplementedError(f"Unary function {name=}")

    def do_evaluate(
        self,
        df: DataFrame,
        *,
        context: ExecutionContext = ExecutionContext.FRAME,
        mapping: Mapping[Expr, Column] | None = None,
    ) -> Column:
        """Evaluate this expression given a dataframe for context."""
        if self.name == "round":
            (ndigits,) = self.options
            # Unpacking from a generator insists on exactly one child.
            (operand,) = (
                child.evaluate(df, context=context, mapping=mapping)
                for child in self.children
            )
            rounded = plc.round.round(
                operand.obj, ndigits, plc.round.RoundingMethod.HALF_UP
            )
            return Column(rounded).sorted_like(operand)

        if self.name == "unique":
            (keep_order,) = self.options
            (operand,) = (
                child.evaluate(df, context=context, mapping=mapping)
                for child in self.children
            )
            # With a single column, keep_any and keep_first coincide for
            # the stable distinct, so keep_any is used throughout.
            keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY
            table = plc.Table([operand.obj])
            if operand.is_sorted:
                # Input flagged as sorted: use `unique` and treat the
                # result as order-preserving.
                keep_order = True
                deduplicated = plc.stream_compaction.unique(
                    table,
                    [0],
                    keep,
                    plc.types.NullEquality.EQUAL,
                )
            else:
                if keep_order:
                    dedup = plc.stream_compaction.stable_distinct
                else:
                    dedup = plc.stream_compaction.distinct
                deduplicated = dedup(
                    table,
                    [0],
                    keep,
                    plc.types.NullEquality.EQUAL,
                    plc.types.NanEquality.ALL_EQUAL,
                )
            (unique_column,) = deduplicated.columns()
            output = Column(unique_column)
            return output.sorted_like(operand) if keep_order else output

        raise NotImplementedError(
            f"Unimplemented unary function {self.name=}"
        )  # pragma: no cover; init trips first

    def collect_agg(self, *, depth: int) -> AggInfo:
        """Collect information about aggregations in groupbys."""
        if depth != 1:
            # Not inside an aggregation: defer to the single child.
            (child,) = self.children
            return child.collect_agg(depth=depth)
        # Inside an aggregation this node must be pre-evaluated. Groupby
        # construction has already ruled out nested aggregations, so the
        # recursion stops here and the node returns itself for pre-eval.
        return AggInfo([(self, plc.aggregation.collect_list(), self)])


class Sort(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def check_agg(agg: expr.Expr) -> int:
NotImplementedError
For unsupported expression nodes.
"""
if isinstance(agg, (expr.BinOp, expr.Cast)):
if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
return max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, expr.Agg):
return 1 + max(GroupBy.check_agg(child) for child in agg.children)
Expand Down
19 changes: 17 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,23 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
else:
raise NotImplementedError(f"No handler for Expr function node with {name=}")
elif isinstance(name, pl_expr.TemporalFunction):
return expr.TemporalFunction(
dtype,
name,
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
elif isinstance(name, str):
return expr.UnaryFunction(
dtype,
name,
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
raise NotImplementedError(
f"No handler for Expr function node with {name=}"
) # pragma: no cover; polars raises on the rust side for now


@_translate_expr.register
Expand Down
28 changes: 28 additions & 0 deletions python/cudf_polars/tests/expressions/test_datetime_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import datetime
from operator import methodcaller

import pytest

import polars as pl
Expand Down Expand Up @@ -32,3 +35,28 @@ def test_datetime_dataframe_scan(dtype):

query = ldf.select(pl.col("b"), pl.col("a"))
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
    "field",
    [
        methodcaller("year"),
        pytest.param(
            methodcaller("day"),
            marks=pytest.mark.xfail(reason="day extraction not implemented"),
        ),
    ],
)
def test_datetime_extract(field):
    """Check extraction of a datetime field from a date column on the GPU."""
    dates = [datetime.date(2024, 1, 1), datetime.date(2024, 10, 11)]
    frame = pl.LazyFrame({"dates": dates})
    temporal_namespace = pl.col("dates").dt
    q = frame.select(field(temporal_namespace))

    # polars produces int32, libcudf produces int16 for the year extraction,
    # so the strict comparison is expected to fail; libcudf can lose data here.
    # https://github.com/rapidsai/cudf/issues/16196
    with pytest.raises(AssertionError):
        assert_gpu_result_equal(q)

    # Values still have to agree once the dtype difference is ignored.
    assert_gpu_result_equal(q, check_dtypes=False)
32 changes: 32 additions & 0 deletions python/cudf_polars/tests/expressions/test_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import math

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(params=[pl.Float32, pl.Float64])
def dtype(request):
    """Parametrized fixture yielding each floating-point dtype under test."""
    float_dtype = request.param
    return float_dtype


@pytest.fixture
def df(dtype, with_nulls):
    """Single-column lazy frame of floats, optionally with two nulls."""
    values = [-math.e, 10, 22.5, 1.5, 2.5, -1.5, math.pi, 8]
    if with_nulls:
        # Null out one interior entry and the final entry.
        for position in (2, -1):
            values[position] = None
    return pl.LazyFrame({"a": values}, schema={"a": dtype})


@pytest.mark.parametrize("decimals", [0, 2, 4])
def test_round(df, decimals):
    """Rounding to ``decimals`` places matches polars up to float tolerance."""
    rounded = pl.col("a").round(decimals=decimals)
    query = df.select(rounded)

    assert_gpu_result_equal(query, check_exact=False)
24 changes: 24 additions & 0 deletions python/cudf_polars/tests/expressions/test_unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("maintain_order", [False, True], ids=["unstable", "stable"])
@pytest.mark.parametrize("pre_sorted", [False, True], ids=["unsorted", "sorted"])
def test_unique(maintain_order, pre_sorted):
    """Compare ``unique`` on GPU and CPU for sorted/unsorted, stable/unstable input."""
    # The data contains duplicates, a null and a NaN.
    data = {"b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3]}
    frame = pl.DataFrame(data).lazy()
    ldf = frame.sort("b") if pre_sorted else frame

    deduplicated = pl.col("b").unique(maintain_order=maintain_order)
    query = ldf.select(deduplicated)
    # Row order is only compared when the query asked to maintain it.
    assert_gpu_result_equal(query, check_row_order=maintain_order)
2 changes: 2 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def keys(request):
[pl.col("float").max() - pl.col("int").min()],
[pl.col("float").mean(), pl.col("int").std()],
[(pl.col("float") - pl.lit(2)).max()],
[pl.col("float").sum().round(decimals=1)],
[pl.col("float").round(decimals=1).sum()],
],
ids=lambda aggs: "-".join(map(str, aggs)),
)
Expand Down
Loading