From c7b28ceeb46d2b921e30f081a9ed97745c91ff9e Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Tue, 23 Jul 2024 05:28:13 -0500 Subject: [PATCH] Add `drop_nulls` in `cudf-polars` (#16290) Closes https://github.com/rapidsai/cudf/issues/16219 Authors: - https://github.com/brandon-b-miller Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/16290 --- python/cudf_polars/cudf_polars/dsl/expr.py | 30 +++++++++- python/cudf_polars/tests/test_drop_nulls.py | 65 +++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 python/cudf_polars/tests/test_drop_nulls.py diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index a034d55120a..8322d6bd6fb 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -882,7 +882,14 @@ def __init__( self.name = name self.options = options self.children = children - if self.name not in ("mask_nans", "round", "setsorted", "unique"): + if self.name not in ( + "mask_nans", + "round", + "setsorted", + "unique", + "dropnull", + "fill_null", + ): raise NotImplementedError(f"Unary function {name=}") def do_evaluate( @@ -968,6 +975,27 @@ def do_evaluate( order=order, null_order=null_order, ) + elif self.name == "dropnull": + (column,) = ( + child.evaluate(df, context=context, mapping=mapping) + for child in self.children + ) + return Column( + plc.stream_compaction.drop_nulls( + plc.Table([column.obj]), [0], 1 + ).columns()[0] + ) + elif self.name == "fill_null": + column = self.children[0].evaluate(df, context=context, mapping=mapping) + if isinstance(self.children[1], Literal): + arg = plc.interop.from_arrow(self.children[1].value) + else: + evaluated = self.children[1].evaluate( + df, context=context, mapping=mapping + ) + arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj + return Column(plc.replace.replace_nulls(column.obj, arg)) + raise NotImplementedError( f"Unimplemented unary function {self.name=}" ) # pragma: no cover; init trips first diff --git a/python/cudf_polars/tests/test_drop_nulls.py b/python/cudf_polars/tests/test_drop_nulls.py new file mode 100644 index 00000000000..5dfe9f66a97 --- /dev/null +++ b/python/cudf_polars/tests/test_drop_nulls.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) + + +@pytest.fixture( + params=[ + [1, 2, 1, 3, 5, None, None], + [1.5, 2.5, None, 1.5, 3, float("nan"), 3], + [], + [None, None], + [1, 2, 3, 4, 5], + ] +) +def null_data(request): + is_empty = pl.Series(request.param).dtype == pl.Null + return pl.DataFrame( + { + "a": pl.Series(request.param, dtype=pl.Float64 if is_empty else None), + "b": pl.Series(request.param, dtype=pl.Float64 if is_empty else None), + } + ).lazy() + + +def test_drop_null(null_data): + q = null_data.select(pl.col("a").drop_nulls()) + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize( + "value", + [0, pl.col("a").mean(), pl.col("b")], + ids=["scalar", "aggregation", "column_expression"], +) +def test_fill_null(null_data, value): + q = null_data.select(pl.col("a").fill_null(value)) + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize( + "strategy", ["forward", "backward", "min", "max", "mean", "zero", "one"] +) +def test_fill_null_with_strategy(null_data, strategy): + q = null_data.select(pl.col("a").fill_null(strategy=strategy)) + + # Not yet exposed to python from rust + assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize("strategy", ["forward", "backward"]) +@pytest.mark.parametrize("limit", [0, 1, 2]) +def test_fill_null_with_limit(null_data, strategy, limit): + q = null_data.select(pl.col("a").fill_null(strategy=strategy, limit=limit)) + + # Not yet exposed to python from rust + assert_ir_translation_raises(q, NotImplementedError)