From e2c34f64cedb28f81503fcdf1ffd77eca576d273 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 12 Jun 2024 13:50:06 +0000 Subject: [PATCH] Add full coverage for whole-frame Agg expressions Also add more expansive comments on the unreachable paths. --- python/cudf_polars/cudf_polars/dsl/expr.py | 58 ++++++++----------- .../cudf_polars/tests/expressions/test_agg.py | 14 +++++ 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index a81cdcbf0c3..8edbe840d51 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -899,7 +899,9 @@ def __init__( self.options = options self.children = (value,) if name not in Agg._SUPPORTED: - raise NotImplementedError(f"Unsupported aggregation {name=}") + raise NotImplementedError( + f"Unsupported aggregation {name=}" + ) # pragma: no cover; all valid aggs are supported # TODO: nan handling in groupby case if name == "min": req = plc.aggregation.min() @@ -925,7 +927,9 @@ def __init__( elif name == "count": req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE) else: - raise NotImplementedError + raise NotImplementedError( + f"Unreachable, {name=} is incorrectly listed in _SUPPORTED" + ) # pragma: no cover self.request = req op = getattr(self, f"_{name}", None) if op is None: @@ -935,7 +939,9 @@ def __init__( elif name in {"count", "first", "last"}: pass else: - raise AssertionError + raise NotImplementedError( + f"Unreachable, supported agg {name=} has no implementation" + ) # pragma: no cover self.op = op _SUPPORTED: ClassVar[frozenset[str]] = frozenset( @@ -957,11 +963,15 @@ def __init__( def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" if depth >= 1: - raise NotImplementedError("Nested aggregations in groupby") + raise NotImplementedError( + "Nested aggregations in groupby" + ) # pragma: no cover; check_agg trips first (child,) = self.children ((expr, _, _),) = child.collect_agg(depth=depth + 1).requests if self.request is None: - raise NotImplementedError(f"Aggregation {self.name} in groupby") + raise NotImplementedError( + f"Aggregation {self.name} in groupby" + ) # pragma: no cover; __init__ trips first return AggInfo([(expr, self.request, self)]) def _reduce( @@ -971,10 +981,7 @@ def _reduce( plc.Column.from_scalar( plc.reduce.reduce(column.obj, request, self.dtype), 1, - ), - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, + ) ) def _count(self, column: Column) -> Column: @@ -987,10 +994,7 @@ def _count(self, column: Column) -> Column: ), ), 1, - ), - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, + ) ) def _min(self, column: Column, *, propagate_nans: bool) -> Column: @@ -1001,10 +1005,7 @@ def _min(self, column: Column, *, propagate_nans: bool) -> Column: pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype)) ), 1, - ), - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, + ) ) if column.nan_count > 0: column = column.mask_nans() @@ -1018,31 +1019,18 @@ def _max(self, column: Column, *, propagate_nans: bool) -> Column: pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype)) ), 1, - ), - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, + ) ) if column.nan_count > 0: column = column.mask_nans() return self._reduce(column, request=plc.aggregation.max()) def _first(self, column: Column) -> Column: - return Column( - plc.copying.slice(column.obj, [0, 1])[0], - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, - ) + return Column(plc.copying.slice(column.obj, [0, 1])[0]) def _last(self, column: Column) -> Column: n = column.obj.size() - return Column( - plc.copying.slice(column.obj, [n - 1, n])[0], - is_sorted=plc.types.Sorted.YES, - order=plc.types.Order.ASCENDING, - null_order=plc.types.NullOrder.BEFORE, - ) + return Column(plc.copying.slice(column.obj, [n - 1, n])[0]) def do_evaluate( self, @@ -1053,7 +1041,9 @@ def do_evaluate( ) -> Column: """Evaluate this expression given a dataframe for context.""" if context is not ExecutionContext.FRAME: - raise NotImplementedError(f"Agg in context {context}") + raise NotImplementedError( + f"Agg in context {context}" + ) # pragma: no cover; unreachable (child,) = self.children return self.op(child.evaluate(df, context=context, mapping=mapping)) diff --git a/python/cudf_polars/tests/expressions/test_agg.py b/python/cudf_polars/tests/expressions/test_agg.py index 79018c80bf3..0f8c103ffc7 100644 --- a/python/cudf_polars/tests/expressions/test_agg.py +++ b/python/cudf_polars/tests/expressions/test_agg.py @@ -61,3 +61,17 @@ def test_agg(df, agg): with pytest.raises(AssertionError): assert_gpu_result_equal(q) assert_gpu_result_equal(q, check_dtypes=check_dtypes, check_exact=False) + + +@pytest.mark.parametrize( + "propagate_nans", + [pytest.param(False, marks=pytest.mark.xfail(reason="Need to mask nans")), True], + ids=["mask_nans", "propagate_nans"], +) +@pytest.mark.parametrize("op", ["min", "max"]) +def test_agg_float_with_nans(propagate_nans, op): + df = pl.LazyFrame({"a": [1, 2, float("nan")]}) + op = getattr(pl.Expr, f"nan_{op}" if propagate_nans else op) + q = df.select(op(pl.col("a"))) + + assert_gpu_result_equal(q)