Skip to content

Commit

Permalink
Implement translation for some unary functions and a single datetime …
Browse files Browse the repository at this point in the history
…extraction (#16173)

- Closes #16169

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)

URL: #16173
  • Loading branch information
wence- authored Jul 5, 2024
1 parent 7dd6945 commit c978181
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 3 deletions.
124 changes: 124 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"Col",
"BooleanFunction",
"StringFunction",
"TemporalFunction",
"Sort",
"SortBy",
"Gather",
Expand Down Expand Up @@ -815,6 +816,129 @@ def do_evaluate(
) # pragma: no cover; handled by init raising


class TemporalFunction(Expr):
    """
    Expression node for datetime component extraction (``.dt`` namespace).

    Currently only ``pl_expr.TemporalFunction.Year`` is supported; any
    other temporal function raises ``NotImplementedError`` at construction
    time so translation fails early (and the query can fall back to CPU).
    """

    __slots__ = ("name", "options", "children")
    _non_child = ("dtype", "name", "options")
    children: tuple[Expr, ...]

    def __init__(
        self,
        dtype: plc.DataType,
        name: pl_expr.TemporalFunction,
        options: tuple[Any, ...],
        *children: Expr,
    ) -> None:
        super().__init__(dtype)
        self.options = options
        self.name = name
        self.children = children
        if self.name != pl_expr.TemporalFunction.Year:
            # BUG FIX: message previously said "String function" — a
            # copy-paste from StringFunction; report the right node kind.
            raise NotImplementedError(f"Temporal function {self.name}")

    def do_evaluate(
        self,
        df: DataFrame,
        *,
        context: ExecutionContext = ExecutionContext.FRAME,
        mapping: Mapping[Expr, Column] | None = None,
    ) -> Column:
        """Evaluate this expression given a dataframe for context."""
        columns = [
            child.evaluate(df, context=context, mapping=mapping)
            for child in self.children
        ]
        if self.name == pl_expr.TemporalFunction.Year:
            # Year extraction takes exactly one input column.
            (column,) = columns
            return Column(plc.datetime.extract_year(column.obj))
        raise NotImplementedError(
            f"TemporalFunction {self.name}"
        )  # pragma: no cover; init trips first


class UnaryFunction(Expr):
    """
    Expression node for a named unary function on a single child.

    Supported names are ``"round"`` (options: ``(decimal_places,)``) and
    ``"unique"`` (options: ``(maintain_order,)``); anything else raises
    ``NotImplementedError`` at construction time.
    """

    __slots__ = ("name", "options", "children")
    _non_child = ("dtype", "name", "options")
    children: tuple[Expr, ...]

    def __init__(
        self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr
    ) -> None:
        super().__init__(dtype)
        self.name = name
        self.options = options
        self.children = children
        if self.name not in ("round", "unique"):
            raise NotImplementedError(f"Unary function {name=}")

    def do_evaluate(
        self,
        df: DataFrame,
        *,
        context: ExecutionContext = ExecutionContext.FRAME,
        mapping: Mapping[Expr, Column] | None = None,
    ) -> Column:
        """Evaluate this expression given a dataframe for context."""
        # Both supported functions are strictly unary, so evaluate the
        # single child up front.
        (child,) = self.children
        values = child.evaluate(df, context=context, mapping=mapping)
        if self.name == "round":
            (decimal_places,) = self.options
            rounded = plc.round.round(
                values.obj, decimal_places, plc.round.RoundingMethod.HALF_UP
            )
            # Rounding preserves the input ordering.
            return Column(rounded).sorted_like(values)
        if self.name == "unique":
            (maintain_order,) = self.options
            # With a single key column, KEEP_ANY coincides with
            # KEEP_FIRST for a stable distinct.
            keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY
            if values.is_sorted:
                # Input already sorted: unique() drops adjacent
                # duplicates and preserves order for free.
                maintain_order = True
                result = plc.stream_compaction.unique(
                    plc.Table([values.obj]),
                    [0],
                    keep,
                    plc.types.NullEquality.EQUAL,
                )
            else:
                if maintain_order:
                    distinct = plc.stream_compaction.stable_distinct
                else:
                    distinct = plc.stream_compaction.distinct
                result = distinct(
                    plc.Table([values.obj]),
                    [0],
                    keep,
                    plc.types.NullEquality.EQUAL,
                    plc.types.NanEquality.ALL_EQUAL,
                )
            (column,) = result.columns()
            out = Column(column)
            if maintain_order:
                # Propagate sortedness flags from the input column.
                out = out.sorted_like(values)
            return out
        raise NotImplementedError(
            f"Unimplemented unary function {self.name=}"
        )  # pragma: no cover; init trips first

    def collect_agg(self, *, depth: int) -> AggInfo:
        """Collect information about aggregations in groupbys."""
        if depth == 1:
            # Inside an aggregation we must pre-evaluate; groupby
            # construction already checked there are no nested aggs, so
            # stop recursing and request pre-evaluation of this node.
            return AggInfo([(self, plc.aggregation.collect_list(), self)])
        (agg_child,) = self.children
        return agg_child.collect_agg(depth=depth)


class Sort(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def check_agg(agg: expr.Expr) -> int:
NotImplementedError
For unsupported expression nodes.
"""
if isinstance(agg, (expr.BinOp, expr.Cast)):
if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
return max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, expr.Agg):
return 1 + max(GroupBy.check_agg(child) for child in agg.children)
Expand Down
19 changes: 17 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,23 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
else:
raise NotImplementedError(f"No handler for Expr function node with {name=}")
elif isinstance(name, pl_expr.TemporalFunction):
return expr.TemporalFunction(
dtype,
name,
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
elif isinstance(name, str):
return expr.UnaryFunction(
dtype,
name,
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
raise NotImplementedError(
f"No handler for Expr function node with {name=}"
) # pragma: no cover; polars raises on the rust side for now


@_translate_expr.register
Expand Down
28 changes: 28 additions & 0 deletions python/cudf_polars/tests/expressions/test_datetime_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import datetime
from operator import methodcaller

import pytest

import polars as pl
Expand Down Expand Up @@ -32,3 +35,28 @@ def test_datetime_dataframe_scan(dtype):

query = ldf.select(pl.col("b"), pl.col("a"))
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
    "field",
    [
        methodcaller("year"),
        pytest.param(
            methodcaller("day"),
            marks=pytest.mark.xfail(reason="day extraction not implemented"),
        ),
    ],
)
def test_datetime_extract(field):
    """Extracting a datetime component matches polars (up to dtype)."""
    dates = [datetime.date(2024, 1, 1), datetime.date(2024, 10, 11)]
    ldf = pl.LazyFrame({"dates": dates})
    q = ldf.select(field(pl.col("dates").dt))

    # Dtypes deliberately disagree: polars produces int32 while libcudf
    # produces int16 for year extraction, so the exact comparison fails.
    # libcudf can lose data here.
    # https://github.com/rapidsai/cudf/issues/16196
    with pytest.raises(AssertionError):
        assert_gpu_result_equal(q)

    assert_gpu_result_equal(q, check_dtypes=False)
32 changes: 32 additions & 0 deletions python/cudf_polars/tests/expressions/test_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import math

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(params=[pl.Float32, pl.Float64])
def dtype(request):
    # Run every test in this module against both floating-point widths.
    return request.param


@pytest.fixture
def df(dtype, with_nulls):
    """Lazy frame of floats with tie-breaking values, optionally nulled."""
    values = [-math.e, 10, 22.5, 1.5, 2.5, -1.5, math.pi, 8]
    if with_nulls:
        # Null out one interior and one trailing entry.
        values[2] = None
        values[-1] = None
    return pl.LazyFrame({"a": values}, schema={"a": dtype})


@pytest.mark.parametrize("decimals", [0, 2, 4])
def test_round(df, decimals):
    """Rounding matches polars; compare inexactly (last-ulp differences)."""
    query = df.select(pl.col("a").round(decimals=decimals))

    assert_gpu_result_equal(query, check_exact=False)
24 changes: 24 additions & 0 deletions python/cudf_polars/tests/expressions/test_unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("maintain_order", [False, True], ids=["unstable", "stable"])
@pytest.mark.parametrize("pre_sorted", [False, True], ids=["unsorted", "sorted"])
def test_unique(maintain_order, pre_sorted):
    """unique() matches polars for sorted/unsorted input, with/without order."""
    data = {"b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3]}
    ldf = pl.DataFrame(data).lazy()
    if pre_sorted:
        ldf = ldf.sort("b")

    query = ldf.select(pl.col("b").unique(maintain_order=maintain_order))
    # Row order is only defined when unique is order-maintaining.
    assert_gpu_result_equal(query, check_row_order=maintain_order)
2 changes: 2 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def keys(request):
[pl.col("float").max() - pl.col("int").min()],
[pl.col("float").mean(), pl.col("int").std()],
[(pl.col("float") - pl.lit(2)).max()],
[pl.col("float").sum().round(decimals=1)],
[pl.col("float").round(decimals=1).sum()],
],
ids=lambda aggs: "-".join(map(str, aggs)),
)
Expand Down

0 comments on commit c978181

Please sign in to comment.