diff --git a/python/cudf_polars/cudf_polars/expressions.py b/python/cudf_polars/cudf_polars/expressions.py index 5423cb20345..144461d80e0 100644 --- a/python/cudf_polars/cudf_polars/expressions.py +++ b/python/cudf_polars/cudf_polars/expressions.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections import defaultdict +from enum import IntEnum, auto from functools import singledispatch from typing import TYPE_CHECKING @@ -35,16 +36,58 @@ from cudf_polars.typing import ColumnType, Expr, Visitor +class ExecutionContext(IntEnum): + """Tag for the current execution context.""" + + GROUPBY = auto() + "Executing inside a group_by expression." + ROLLING = auto() + "Executing inside a rolling expression." + DATAFRAME = auto() + "Executing on the whole dataframe." + + class ExprVisitor: """Object holding rust visitor and utility methods.""" - __slots__ = ("visitor", "in_groupby") + __slots__ = ("visitor", "context", "node_stack") visitor: Visitor - in_groupby: bool + context: ExecutionContext + node_stack: list[int] + + class _with_context: + def __init__(self, context: ExecutionContext, visitor: ExprVisitor): + self.context = context + self.visitor = visitor + + def __enter__(self): + self.visitor.context, self.context = ( + self.context, + self.visitor.context, + ) + + def __exit__(self, *args): + self.visitor.context = self.context def __init__(self, visitor: Visitor): self.visitor = visitor - self.in_groupby = False + self.context = ExecutionContext.DATAFRAME + self.node_stack = [] + + def with_context(self, context: ExecutionContext): + """ + Context manager for setting the execution context of the visitor. + + Parameters + ---------- + context + New execution context + + Returns + ------- + context manager that sets and restores the execution context. + """ + return self._with_context(context, self) def add_expressions( self, expressions: Sequence[Expr] @@ -94,7 +137,17 @@ def __call__(self, node: int, context: DataFrame) -> ColumnType: ------- New column as the evaluation of the expression. """ - return evaluate_expr(self.visitor.view_expression(node), context, self) + self.node_stack.append(node) + result = evaluate_expr( + self.visitor.view_expression(node), context, self + ) + self.node_stack.pop() + return result + + @property + def dtype(self): + """Return the datatype of the current expression node.""" + return self.visitor.get_dtype(self.node_stack[-1]) @singledispatch @@ -310,11 +363,6 @@ def _expr_function( # TODO: tracking sortedness (column,) = arguments return column - # (name,) = data.keys() - # (flag,) = fargs - # return data.set_sorted( - # {name: getattr(DataFrame.IsSorted, flag.upper())} - # ) elif fname in BOOLEAN_FUNCTIONS: return boolean_function(fname, arguments, fargs) else: @@ -365,33 +413,26 @@ def _literal( @evaluate_expr.register def _sort(expr: expr_nodes.Sort, context: DataFrame, visitor: ExprVisitor): - if visitor.in_groupby: - raise NotImplementedError("sort inside groupby") + if visitor.context is not ExecutionContext.DATAFRAME: + raise NotImplementedError("sort inside groupby/rolling") to_sort = visitor(expr.expr, context) (stable, nulls_last, descending) = expr.options descending, column_order, null_precedence = sort_order( [descending], nulls_last=nulls_last, num_keys=1 ) do_sort = plc.sorting.stable_sort if stable else plc.sorting.sort - result = do_sort(to_sort.to_pylibcudf(), column_order, null_precedence) + (result,) = do_sort( + plc.Table([to_sort]), column_order, null_precedence + ).columns() return result - # TODO: track sortedness - # flag = ( - # DataFrame.IsSorted.DESCENDING - # if descending - # else DataFrame.IsSorted.ASCENDING - # ) - # return DataFrame.from_pylibcudf(to_sort.names(), result).set_sorted( - # {name: flag} - # ) @evaluate_expr.register def _sort_by( expr: expr_nodes.SortBy, context: DataFrame, visitor: ExprVisitor ): - if visitor.in_groupby: - raise NotImplementedError("sort_by inside groupby") + if visitor.context is not ExecutionContext.DATAFRAME: + raise NotImplementedError("sort_by inside groupby/rolling") to_sort = visitor(expr.expr, context) descending = expr.descending sort_keys = [visitor(e, context) for e in expr.by] @@ -399,12 +440,13 @@ def _sort_by( descending, column_order, null_precedence = sort_order( descending, nulls_last=True, num_keys=len(sort_keys) ) - return plc.sorting.sort_by_key( + (result,) = plc.sorting.sort_by_key( plc.Table([to_sort]), plc.Table(sort_keys), column_order, null_precedence, ) + return result @evaluate_expr.register @@ -432,8 +474,8 @@ def _gather(expr: expr_nodes.Gather, context: DataFrame, visitor: ExprVisitor): @evaluate_expr.register def _filter(expr: expr_nodes.Filter, context: DataFrame, visitor: ExprVisitor): - if visitor.in_groupby: - raise NotImplementedError("filter inside groupby") + if visitor.context is not ExecutionContext.DATAFRAME: + raise NotImplementedError("filter inside groupby/rolling") result = visitor(expr.input, context) mask = visitor(expr.by, context) (column,) = plc.stream_compaction.apply_boolean_mask( @@ -458,8 +500,8 @@ def _column(expr: expr_nodes.Column, context: DataFrame, visitor: ExprVisitor): @evaluate_expr.register def _agg(expr: expr_nodes.Agg, context: DataFrame, visitor: ExprVisitor): - if visitor.in_groupby: - raise NotImplementedError("nested agg in group_by") + if visitor.context is not ExecutionContext.DATAFRAME: + raise NotImplementedError("nested agg in groupby/rolling") name = expr.name column = visitor(expr.arguments, context) # TODO: handle options @@ -711,9 +753,8 @@ def collect_agg( return [*lcol, *rcol], [*lreq, *rreq] else: # TODO: Ugly non-local method of saying "we're in a groupby, disallow" - visitor.in_groupby = True - column = evaluate_expr(agg, context, visitor) - visitor.in_groupby = False + with visitor.with_context(ExecutionContext.GROUPBY): + column = evaluate_expr(agg, context, visitor) return [column], [(plc.aggregation.collect_list(), node)] elif isinstance(agg, expr_nodes.Literal): # Scalar value, constant across the groups diff --git a/python/cudf_polars/cudf_polars/plan.py b/python/cudf_polars/cudf_polars/plan.py index 2a70390ad95..30794c7b653 100644 --- a/python/cudf_polars/cudf_polars/plan.py +++ b/python/cudf_polars/cudf_polars/plan.py @@ -121,14 +121,17 @@ def __call__(self, n: int | None = None) -> DataFrame: Node to evaluate (optional), if not provided uses the internal visitor's "current" node. + Notes + ----- + Side-effectfully modifies the visitor to set the current node. + Returns ------- New dataframe giving the evaluation of the plan. """ - if n is None: - node = self.visitor.view_current_node() - else: - node = self.visitor.view_node(n) + if n is not None: + self.visitor.set_node(n) + node = self.visitor.view_current_node() return _execute_plan(node, self) def record(self, name: str): @@ -200,8 +203,6 @@ def _python_scan(plan: nodes.PythonScan, visitor: PlanVisitor): with visitor.record("PythonScan"): ( scan_fn, - schema, - output_schema, with_columns, is_pyarrow, predicate, @@ -211,6 +212,8 @@ def _python_scan(plan: nodes.PythonScan, visitor: PlanVisitor): if is_pyarrow: raise NotImplementedError("Don't know what to do here") context = scan_fn(with_columns, predicate, nrows) + if not isinstance(context, DataFrame): + raise TypeError(f"Don't know how to handle a {type(context)}") if predicate is not None: mask = visitor.expr_visitor(predicate.node, context) return context.filter(mask) @@ -227,7 +230,7 @@ def _scan(plan: nodes.Scan, visitor: PlanVisitor): n_rows = options.n_rows with_columns = options.with_columns row_index = options.row_index - schema = plan.output_schema + schema = visitor.visitor.get_schema() # TODO: Send all the options through to the libcudf readers where appropriate if n_rows is not None: # TODO: read_csv supports n_rows, but if we have more than one @@ -422,7 +425,7 @@ def _join(plan: nodes.Join, visitor: PlanVisitor): right_on = plc.Table( [visitor.expr_visitor(e.node, right) for e in plan.right_on] ) - how, join_nulls, zlice, suffix = plan.options + how, join_nulls, zlice, suffix, coalesce_key_columns = plan.options null_equality = ( plc.types.NullEquality.EQUAL if join_nulls @@ -431,9 +434,7 @@ def _join(plan: nodes.Join, visitor: PlanVisitor): suffix = "_right" if suffix is None else suffix if how == "cross": raise NotImplementedError("cross join not implemented") - coalesce_key_columns = True - if how == "outer": - coalesce_key_columns = False + if how == "outer" and not coalesce_key_columns: raise NotImplementedError("Non-coalescing outer join") elif how == "outer_coalesce": how = "outer" @@ -576,7 +577,8 @@ def _sort(plan: nodes.Sort, visitor: PlanVisitor): sort_keys = [ visitor.expr_visitor(e.node, result) for e in plan.by_column ] - (stable, nulls_last, descending, zlice) = plan.args + (descending, nulls_last, stable) = plan.sort_options + zlice = plan.slice descending, column_order, null_precedence = sort_order( descending, nulls_last=nulls_last, num_keys=len(sort_keys) ) @@ -632,8 +634,8 @@ def _filter(plan: nodes.Filter, visitor: PlanVisitor): @_execute_plan.register def _simple_projection(plan: nodes.SimpleProjection, visitor: PlanVisitor): + schema = visitor.visitor.get_schema() result = visitor(plan.input) - schema = plan.columns with visitor.record("simple_projection"): return DataFrame({name: result[name] for name in schema}) @@ -708,7 +710,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor): elif typ == "explode": context = visitor(plan.input) with profiler: - column_names, schema = args + (column_names,) = args if len(column_names) > 1: # TODO: straightforward, but need to error check # polars requires that all to-explode columns have the diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 37b2936c6ee..19e715d04a5 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -114,6 +114,10 @@ ignore = [ "W191", ] +[tool.ruff.lint.per-file-ignores] +# No need for docstrings on tests +"tests/**.py" = ["D"] + [tool.ruff.lint.pycodestyle] max-doc-length = 85 diff --git a/python/cudf_polars/tests/test_basic.py b/python/cudf_polars/tests/test_basic.py index 71eef925179..b8468258fd3 100644 --- a/python/cudf_polars/tests/test_basic.py +++ b/python/cudf_polars/tests/test_basic.py @@ -7,14 +7,8 @@ import numpy as np import polars as pl import pytest -from polars.testing.asserts import assert_frame_equal -from cudf_polars.patch import _WAS_PATCHED - -if not _WAS_PATCHED: - # We could also just patch in the test, but this approach provides a canary for - # failures with patching that we might observe in trying this with other tests. - raise RuntimeError("Patch was not applied") +from cudf_polars.testing.asserts import assert_gpu_result_equal @pytest.fixture() @@ -54,12 +48,6 @@ def ldf(df): return df.lazy() -def assert_gpu_result_equal(lazydf, **kwargs): - expect = lazydf.collect(use_gpu=False) - got = lazydf.collect(use_gpu=True, cpu_fallback=False) - assert_frame_equal(expect, got, **kwargs) - - @pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"]) @pytest.mark.parametrize( "op", [operator.add, operator.sub, operator.mul, operator.truediv] diff --git a/python/cudf_polars/tests/test_distinct.py b/python/cudf_polars/tests/test_distinct.py new file mode 100644 index 00000000000..533886ed600 --- /dev/null +++ b/python/cudf_polars/tests/test_distinct.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +import polars as pl +import pytest + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.mark.parametrize("subset", [None, ["a"], ["a", "b"], ["b", "c"]]) +@pytest.mark.parametrize("keep", ["any", "none", "first", "last"]) +@pytest.mark.parametrize( + "maintain_order", [False, True], ids=["unstable", "stable"] +) +def test_distinct(subset, keep, maintain_order): + ldf = pl.DataFrame( + { + "a": [1, 2, 1, 3, 5, None, None], + "b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3], + "c": [True, True, True, True, False, False, True], + } + ).lazy() + + query = ldf.unique(subset=subset, keep=keep, maintain_order=maintain_order) + assert_gpu_result_equal(query, check_row_order=maintain_order) diff --git a/python/cudf_polars/tests/test_sort.py b/python/cudf_polars/tests/test_sort.py new file mode 100644 index 00000000000..1051b9f80c6 --- /dev/null +++ b/python/cudf_polars/tests/test_sort.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +import polars as pl +import pytest + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.mark.parametrize( + "sort_keys", + [ + (pl.col("a"),), + pytest.param( + (pl.col("d").abs(),), + marks=pytest.mark.xfail(reason="abs not yet implemented"), + ), + (pl.col("a"), pl.col("d")), + (pl.col("b"),), + ], +) +@pytest.mark.parametrize("nulls_last", [False, True]) +@pytest.mark.parametrize( + "maintain_order", [False, True], ids=["unstable", "stable"] +) +def test_distinct(sort_keys, nulls_last, maintain_order): + ldf = pl.DataFrame( + { + "a": [1, 2, 1, 3, 5, None, None], + "b": [1.5, 2.5, None, 1.5, 3, float("nan"), 3], + "c": [True, True, True, True, False, False, True], + "d": [1, 2, -1, 10, 6, -1, -7], + } + ).lazy() + + query = ldf.sort( + *sort_keys, + descending=True, + nulls_last=nulls_last, + maintain_order=maintain_order, + ) + assert_gpu_result_equal(query, check_row_order=maintain_order)