Skip to content

Commit

Permalink
Add drop_nulls in cudf-polars (#16290)
Browse files Browse the repository at this point in the history
Closes #16219

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #16290
  • Loading branch information
brandon-b-miller authored Jul 23, 2024
1 parent 0cac2a9 commit c7b28ce
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
30 changes: 29 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,14 @@ def __init__(
self.name = name
self.options = options
self.children = children
if self.name not in ("mask_nans", "round", "setsorted", "unique"):
if self.name not in (
"mask_nans",
"round",
"setsorted",
"unique",
"dropnull",
"fill_null",
):
raise NotImplementedError(f"Unary function {name=}")

def do_evaluate(
Expand Down Expand Up @@ -968,6 +975,27 @@ def do_evaluate(
order=order,
null_order=null_order,
)
elif self.name == "dropnull":
(column,) = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
return Column(
plc.stream_compaction.drop_nulls(
plc.Table([column.obj]), [0], 1
).columns()[0]
)
elif self.name == "fill_null":
column = self.children[0].evaluate(df, context=context, mapping=mapping)
if isinstance(self.children[1], Literal):
arg = plc.interop.from_arrow(self.children[1].value)
else:
evaluated = self.children[1].evaluate(
df, context=context, mapping=mapping
)
arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj
return Column(plc.replace.replace_nulls(column.obj, arg))

raise NotImplementedError(
f"Unimplemented unary function {self.name=}"
) # pragma: no cover; init trips first
Expand Down
65 changes: 65 additions & 0 deletions python/cudf_polars/tests/test_drop_nulls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
assert_ir_translation_raises,
)


@pytest.fixture(
params=[
[1, 2, 1, 3, 5, None, None],
[1.5, 2.5, None, 1.5, 3, float("nan"), 3],
[],
[None, None],
[1, 2, 3, 4, 5],
]
)
def null_data(request):
is_empty = pl.Series(request.param).dtype == pl.Null
return pl.DataFrame(
{
"a": pl.Series(request.param, dtype=pl.Float64 if is_empty else None),
"b": pl.Series(request.param, dtype=pl.Float64 if is_empty else None),
}
).lazy()


def test_drop_null(null_data):
q = null_data.select(pl.col("a").drop_nulls())
assert_gpu_result_equal(q)


@pytest.mark.parametrize(
"value",
[0, pl.col("a").mean(), pl.col("b")],
ids=["scalar", "aggregation", "column_expression"],
)
def test_fill_null(null_data, value):
q = null_data.select(pl.col("a").fill_null(value))
assert_gpu_result_equal(q)


@pytest.mark.parametrize(
"strategy", ["forward", "backward", "min", "max", "mean", "zero", "one"]
)
def test_fill_null_with_strategy(null_data, strategy):
q = null_data.select(pl.col("a").fill_null(strategy=strategy))

# Not yet exposed to python from rust
assert_ir_translation_raises(q, NotImplementedError)


@pytest.mark.parametrize("strategy", ["forward", "backward"])
@pytest.mark.parametrize("limit", [0, 1, 2])
def test_fill_null_with_limit(null_data, strategy, limit):
q = null_data.select(pl.col("a").fill_null(strategy=strategy, limit=limit))

# Not yet exposed to python from rust
assert_ir_translation_raises(q, NotImplementedError)

0 comments on commit c7b28ce

Please sign in to comment.