Skip to content

Commit

Permalink
Handle nans for nan-ignoring aggs in groupby-agg
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Jul 9, 2024
1 parent f85a899 commit 587e373
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 12 deletions.
12 changes: 11 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def __init__(
self.name = name
self.options = options
self.children = children
if self.name not in ("round", "unique"):
if self.name not in ("round", "unique", "mask_nans"):
raise NotImplementedError(f"Unary function {name=}")

def do_evaluate(
Expand All @@ -878,6 +878,9 @@ def do_evaluate(
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if self.name == "mask_nans":
(child,) = self.children
return child.evaluate(df, context=context, mapping=mapping).mask_nans()
if self.name == "round":
(decimal_places,) = self.options
(values,) = (
Expand Down Expand Up @@ -1215,12 +1218,19 @@ def collect_agg(self, *, depth: int) -> AggInfo:
raise NotImplementedError(
"Nested aggregations in groupby"
) # pragma: no cover; check_agg trips first
if (isminmax := self.name in {"min", "max"}) and self.options:
raise NotImplementedError("Nan propagation in groupby for min/max")
(child,) = self.children
((expr, _, _),) = child.collect_agg(depth=depth + 1).requests
if self.request is None:
raise NotImplementedError(
f"Aggregation {self.name} in groupby"
) # pragma: no cover; __init__ trips first
if isminmax and plc.traits.is_floating_point(self.dtype):
assert expr is not None
# Ignore nans in these groupby aggs, do this by masking
# nans in the input
expr = UnaryFunction(self.dtype, "mask_nans", (), expr)
return AggInfo([(expr, self.request, self)])

def _reduce(
Expand Down
11 changes: 7 additions & 4 deletions python/cudf_polars/tests/containers/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ def test_mask_nans(typeid, constructor):
values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype))
column = constructor(plc.interop.from_arrow(values))
masked = column.mask_nans()
assert column.obj is masked.obj
assert column.obj.null_count() == masked.obj.null_count()


def test_mask_nans_float_with_nan_notimplemented():
def test_mask_nans_float():
dtype = plc.DataType(plc.TypeId.FLOAT32)
values = pyarrow.array([0, 0, float("nan")], type=plc.interop.to_arrow(dtype))
column = Column(plc.interop.from_arrow(values))
with pytest.raises(NotImplementedError):
_ = column.mask_nans()
masked = column.mask_nans()
expect = pyarrow.array([0, 0, None], type=plc.interop.to_arrow(dtype))
got = pyarrow.array(plc.interop.to_arrow(masked.obj))

assert expect == got
25 changes: 18 additions & 7 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,25 @@ def test_agg(df, agg):


@pytest.mark.parametrize(
"propagate_nans",
[pytest.param(False, marks=pytest.mark.xfail(reason="Need to mask nans")), True],
ids=["mask_nans", "propagate_nans"],
"op", [pl.Expr.min, pl.Expr.nan_min, pl.Expr.max, pl.Expr.nan_max]
)
@pytest.mark.parametrize("op", ["min", "max"])
def test_agg_float_with_nans(propagate_nans, op):
df = pl.LazyFrame({"a": pl.Series([1, 2, float("nan")], dtype=pl.Float64())})
op = getattr(pl.Expr, f"nan_{op}" if propagate_nans else op)
def test_agg_float_with_nans(op):
df = pl.LazyFrame(
{
"a": pl.Series([1, 2, float("nan")], dtype=pl.Float64()),
"b": pl.Series([1, 2, None], dtype=pl.Int8()),
}
)
q = df.select(op(pl.col("a")), op(pl.col("b")))

assert_gpu_result_equal(q)


@pytest.mark.xfail(reason="https://github.com/pola-rs/polars/issues/17513")
@pytest.mark.parametrize("op", [pl.Expr.max, pl.Expr.min])
def test_agg_singleton(op):
    """Compare GPU and CPU results for min/max over a lone NaN value.

    The input frame has one column, ``a``, holding a single NaN. The
    test is marked xfail; the reason links to the tracking polars issue.
    """
    nan_only = pl.Series([float("nan")])
    frame = pl.LazyFrame({"a": nan_only})
    query = frame.select(op(pl.col("a")))
    assert_gpu_result_equal(query)
24 changes: 24 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,27 @@ def test_groupby_unsupported(df, expr):
q = df.group_by("key1").agg(expr)

assert_ir_translation_raises(q, NotImplementedError)


@pytest.mark.xfail(reason="https://github.com/pola-rs/polars/issues/17513")
def test_groupby_minmax_with_nan():
    """Compare GPU and CPU results for grouped min/max on data with NaNs.

    Group ``1`` contains only a NaN; group ``2`` mixes finite values
    with a NaN. The test is marked xfail; the reason links to the
    tracking polars issue.
    """
    nan = float("nan")
    frame = pl.LazyFrame({"key": [1, 2, 2, 2], "value": [nan, 1, -1, nan]})
    value = pl.col("value")
    query = frame.group_by("key").agg(
        value.max().alias("max"),
        value.min().alias("min"),
    )
    assert_gpu_result_equal(query)


@pytest.mark.parametrize("op", [pl.Expr.nan_max, pl.Expr.nan_min])
def test_groupby_nan_minmax_raises(op):
    """Check that grouped ``nan_max``/``nan_min`` is rejected at translation.

    Builds a group_by aggregation using the parametrized NaN-propagating
    op and asserts, via ``assert_ir_translation_raises``, that translating
    the query raises ``NotImplementedError``.
    """
    nan = float("nan")
    data = {"key": [1, 2, 2, 2], "value": [nan, 1, -1, nan]}
    aggregation = op(pl.col("value"))
    query = pl.LazyFrame(data).group_by("key").agg(aggregation)
    assert_ir_translation_raises(query, NotImplementedError)

0 comments on commit 587e373

Please sign in to comment.