From c978181a3a721ed75cf016c6f083648c65bd24cd Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 5 Jul 2024 16:11:07 +0100 Subject: [PATCH] Implement translation for some unary functions and a single datetime extraction (#16173) - Closes #16169 Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Thomas Li (https://github.com/lithomas1) URL: https://github.com/rapidsai/cudf/pull/16173 --- python/cudf_polars/cudf_polars/dsl/expr.py | 124 ++++++++++++++++++ python/cudf_polars/cudf_polars/dsl/ir.py | 2 +- .../cudf_polars/cudf_polars/dsl/translate.py | 19 ++- .../tests/expressions/test_datetime_basic.py | 28 ++++ .../tests/expressions/test_round.py | 32 +++++ .../tests/expressions/test_unique.py | 24 ++++ python/cudf_polars/tests/test_groupby.py | 2 + 7 files changed, 228 insertions(+), 3 deletions(-) create mode 100644 python/cudf_polars/tests/expressions/test_round.py create mode 100644 python/cudf_polars/tests/expressions/test_unique.py diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 69bc85b109d..93cb9db7cbd 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -44,6 +44,7 @@ "Col", "BooleanFunction", "StringFunction", + "TemporalFunction", "Sort", "SortBy", "Gather", @@ -815,6 +816,129 @@ def do_evaluate( ) # pragma: no cover; handled by init raising +class TemporalFunction(Expr): + __slots__ = ("name", "options", "children") + _non_child = ("dtype", "name", "options") + children: tuple[Expr, ...] + + def __init__( + self, + dtype: plc.DataType, + name: pl_expr.TemporalFunction, + options: tuple[Any, ...], + *children: Expr, + ) -> None: + super().__init__(dtype) + self.options = options + self.name = name + self.children = children + if self.name != pl_expr.TemporalFunction.Year: + raise NotImplementedError(f"String function {self.name}") + + def do_evaluate( + self, + df: DataFrame, + *, + context: ExecutionContext = ExecutionContext.FRAME, + mapping: Mapping[Expr, Column] | None = None, + ) -> Column: + """Evaluate this expression given a dataframe for context.""" + columns = [ + child.evaluate(df, context=context, mapping=mapping) + for child in self.children + ] + if self.name == pl_expr.TemporalFunction.Year: + (column,) = columns + return Column(plc.datetime.extract_year(column.obj)) + raise NotImplementedError( + f"TemporalFunction {self.name}" + ) # pragma: no cover; init trips first + + +class UnaryFunction(Expr): + __slots__ = ("name", "options", "children") + _non_child = ("dtype", "name", "options") + children: tuple[Expr, ...] + + def __init__( + self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr + ) -> None: + super().__init__(dtype) + self.name = name + self.options = options + self.children = children + if self.name not in ("round", "unique"): + raise NotImplementedError(f"Unary function {name=}") + + def do_evaluate( + self, + df: DataFrame, + *, + context: ExecutionContext = ExecutionContext.FRAME, + mapping: Mapping[Expr, Column] | None = None, + ) -> Column: + """Evaluate this expression given a dataframe for context.""" + if self.name == "round": + (decimal_places,) = self.options + (values,) = ( + child.evaluate(df, context=context, mapping=mapping) + for child in self.children + ) + return Column( + plc.round.round( + values.obj, decimal_places, plc.round.RoundingMethod.HALF_UP + ) + ).sorted_like(values) + elif self.name == "unique": + (maintain_order,) = self.options + (values,) = ( + child.evaluate(df, context=context, mapping=mapping) + for child in self.children + ) + # Only one column, so keep_any is the same as keep_first + # for stable distinct + keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY + if values.is_sorted: + maintain_order = True + result = plc.stream_compaction.unique( + plc.Table([values.obj]), + [0], + keep, + plc.types.NullEquality.EQUAL, + ) + else: + distinct = ( + plc.stream_compaction.stable_distinct + if maintain_order + else plc.stream_compaction.distinct + ) + result = distinct( + plc.Table([values.obj]), + [0], + keep, + plc.types.NullEquality.EQUAL, + plc.types.NanEquality.ALL_EQUAL, + ) + (column,) = result.columns() + if maintain_order: + return Column(column).sorted_like(values) + return Column(column) + raise NotImplementedError( + f"Unimplemented unary function {self.name=}" + ) # pragma: no cover; init trips first + + def collect_agg(self, *, depth: int) -> AggInfo: + """Collect information about aggregations in groupbys.""" + if depth == 1: + # inside aggregation, need to pre-evaluate, groupby + # construction has checked that we don't have nested aggs, + # so stop the recursion and return ourselves for pre-eval + return AggInfo([(self, plc.aggregation.collect_list(), self)]) + else: + (child,) = self.children + return child.collect_agg(depth=depth) + + class Sort(Expr): __slots__ = ("options", "children") _non_child = ("dtype", "options") diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 31a0be004ea..6b552642e88 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -433,7 +433,7 @@ def check_agg(agg: expr.Expr) -> int: NotImplementedError For unsupported expression nodes. """ - if isinstance(agg, (expr.BinOp, expr.Cast)): + if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)): return max(GroupBy.check_agg(child) for child in agg.children) elif isinstance(agg, expr.Agg): return 1 + max(GroupBy.check_agg(child) for child in agg.children) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 0019b3aa98a..5a1e682abe7 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -361,8 +361,23 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex options, *(translate_expr(visitor, n=n) for n in node.input), ) - else: - raise NotImplementedError(f"No handler for Expr function node with {name=}") + elif isinstance(name, pl_expr.TemporalFunction): + return expr.TemporalFunction( + dtype, + name, + options, + *(translate_expr(visitor, n=n) for n in node.input), + ) + elif isinstance(name, str): + return expr.UnaryFunction( + dtype, + name, + options, + *(translate_expr(visitor, n=n) for n in node.input), + ) + raise NotImplementedError( + f"No handler for Expr function node with {name=}" + ) # pragma: no cover; polars raises on the rust side for now @_translate_expr.register diff --git a/python/cudf_polars/tests/expressions/test_datetime_basic.py b/python/cudf_polars/tests/expressions/test_datetime_basic.py index 6ba2a1dce1e..218101bf87c 100644 --- a/python/cudf_polars/tests/expressions/test_datetime_basic.py +++ b/python/cudf_polars/tests/expressions/test_datetime_basic.py @@ -2,6 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import datetime +from operator import methodcaller + import pytest import polars as pl @@ -32,3 +35,28 @@ def test_datetime_dataframe_scan(dtype): query = ldf.select(pl.col("b"), pl.col("a")) assert_gpu_result_equal(query) + + +@pytest.mark.parametrize( + "field", + [ + methodcaller("year"), + pytest.param( + methodcaller("day"), + marks=pytest.mark.xfail(reason="day extraction not implemented"), + ), + ], +) +def test_datetime_extract(field): + ldf = pl.LazyFrame( + {"dates": [datetime.date(2024, 1, 1), datetime.date(2024, 10, 11)]} + ) + q = ldf.select(field(pl.col("dates").dt)) + + with pytest.raises(AssertionError): + # polars produces int32, libcudf produces int16 for the year extraction + # libcudf can lose data here. + # https://github.com/rapidsai/cudf/issues/16196 + assert_gpu_result_equal(q) + + assert_gpu_result_equal(q, check_dtypes=False) diff --git a/python/cudf_polars/tests/expressions/test_round.py b/python/cudf_polars/tests/expressions/test_round.py new file mode 100644 index 00000000000..3af3a0ce6d1 --- /dev/null +++ b/python/cudf_polars/tests/expressions/test_round.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import math + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture(params=[pl.Float32, pl.Float64]) +def dtype(request): + return request.param + + +@pytest.fixture +def df(dtype, with_nulls): + a = [-math.e, 10, 22.5, 1.5, 2.5, -1.5, math.pi, 8] + if with_nulls: + a[2] = None + a[-1] = None + return pl.LazyFrame({"a": a}, schema={"a": dtype}) + + +@pytest.mark.parametrize("decimals", [0, 2, 4]) +def test_round(df, decimals): + q = df.select(pl.col("a").round(decimals=decimals)) + + assert_gpu_result_equal(q, check_exact=False) diff --git a/python/cudf_polars/tests/expressions/test_unique.py b/python/cudf_polars/tests/expressions/test_unique.py new file mode 100644 index 00000000000..9b009a422c2 --- /dev/null +++ b/python/cudf_polars/tests/expressions/test_unique.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.mark.parametrize("maintain_order", [False, True], ids=["unstable", "stable"]) +@pytest.mark.parametrize("pre_sorted", [False, True], ids=["unsorted", "sorted"]) +def test_unique(maintain_order, pre_sorted): + ldf = pl.DataFrame( + { + "b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3], + } + ).lazy() + if pre_sorted: + ldf = ldf.sort("b") + + query = ldf.select(pl.col("b").unique(maintain_order=maintain_order)) + assert_gpu_result_equal(query, check_row_order=maintain_order) diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index 8a6732b7063..b84e2c16b43 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -47,6 +47,8 @@ def keys(request): [pl.col("float").max() - pl.col("int").min()], [pl.col("float").mean(), pl.col("int").std()], [(pl.col("float") - pl.lit(2)).max()], + [pl.col("float").sum().round(decimals=1)], + [pl.col("float").round(decimals=1).sum()], ], ids=lambda aggs: "-".join(map(str, aggs)), )