Skip to content

Commit

Permalink
Add full coverage for whole-frame Agg expressions
Browse files Browse the repository at this point in the history
Also add more expansive comments on the unreachable paths.
  • Loading branch information
wence- committed Jun 12, 2024
1 parent f7ba6ab commit e2c34f6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
58 changes: 24 additions & 34 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,9 @@ def __init__(
self.options = options
self.children = (value,)
if name not in Agg._SUPPORTED:
raise NotImplementedError(f"Unsupported aggregation {name=}")
raise NotImplementedError(
f"Unsupported aggregation {name=}"
) # pragma: no cover; all valid aggs are supported
# TODO: nan handling in groupby case
if name == "min":
req = plc.aggregation.min()
Expand All @@ -925,7 +927,9 @@ def __init__(
elif name == "count":
req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE)
else:
raise NotImplementedError
raise NotImplementedError(
f"Unreachable, {name=} is incorrectly listed in _SUPPORTED"
) # pragma: no cover
self.request = req
op = getattr(self, f"_{name}", None)
if op is None:
Expand All @@ -935,7 +939,9 @@ def __init__(
elif name in {"count", "first", "last"}:
pass
else:
raise AssertionError
raise NotImplementedError(
f"Unreachable, supported agg {name=} has no implementation"
) # pragma: no cover
self.op = op

_SUPPORTED: ClassVar[frozenset[str]] = frozenset(
Expand All @@ -957,11 +963,15 @@ def __init__(
def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
if depth >= 1:
raise NotImplementedError("Nested aggregations in groupby")
raise NotImplementedError(
"Nested aggregations in groupby"
) # pragma: no cover; check_agg trips first
(child,) = self.children
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
if self.request is None:
raise NotImplementedError(f"Aggregation {self.name} in groupby")
raise NotImplementedError(
f"Aggregation {self.name} in groupby"
) # pragma: no cover; __init__ trips first
return AggInfo([(expr, self.request, self)])

def _reduce(
Expand All @@ -971,10 +981,7 @@ def _reduce(
plc.Column.from_scalar(
plc.reduce.reduce(column.obj, request, self.dtype),
1,
),
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
)

def _count(self, column: Column) -> Column:
Expand All @@ -987,10 +994,7 @@ def _count(self, column: Column) -> Column:
),
),
1,
),
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
)

def _min(self, column: Column, *, propagate_nans: bool) -> Column:
Expand All @@ -1001,10 +1005,7 @@ def _min(self, column: Column, *, propagate_nans: bool) -> Column:
pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
),
1,
),
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
)
if column.nan_count > 0:
column = column.mask_nans()
Expand All @@ -1018,31 +1019,18 @@ def _max(self, column: Column, *, propagate_nans: bool) -> Column:
pa.scalar(float("nan"), type=plc.interop.to_arrow(self.dtype))
),
1,
),
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
)
if column.nan_count > 0:
column = column.mask_nans()
return self._reduce(column, request=plc.aggregation.max())

def _first(self, column: Column) -> Column:
return Column(
plc.copying.slice(column.obj, [0, 1])[0],
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
return Column(plc.copying.slice(column.obj, [0, 1])[0])

def _last(self, column: Column) -> Column:
n = column.obj.size()
return Column(
plc.copying.slice(column.obj, [n - 1, n])[0],
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
return Column(plc.copying.slice(column.obj, [n - 1, n])[0])

def do_evaluate(
self,
Expand All @@ -1053,7 +1041,9 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if context is not ExecutionContext.FRAME:
raise NotImplementedError(f"Agg in context {context}")
raise NotImplementedError(
f"Agg in context {context}"
) # pragma: no cover; unreachable
(child,) = self.children
return self.op(child.evaluate(df, context=context, mapping=mapping))

Expand Down
14 changes: 14 additions & 0 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,17 @@ def test_agg(df, agg):
with pytest.raises(AssertionError):
assert_gpu_result_equal(q)
assert_gpu_result_equal(q, check_dtypes=check_dtypes, check_exact=False)


@pytest.mark.parametrize(
"propagate_nans",
[pytest.param(False, marks=pytest.mark.xfail(reason="Need to mask nans")), True],
ids=["mask_nans", "propagate_nans"],
)
@pytest.mark.parametrize("op", ["min", "max"])
def test_agg_float_with_nans(propagate_nans, op):
df = pl.LazyFrame({"a": [1, 2, float("nan")]})
op = getattr(pl.Expr, f"nan_{op}" if propagate_nans else op)
q = df.select(op(pl.col("a")))

assert_gpu_result_equal(q)

0 comments on commit e2c34f6

Please sign in to comment.