From cb2a5a4c54bd33a142832b3f4641e49cea57be68 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 12 Jul 2024 14:24:57 +0000 Subject: [PATCH] Test nans_to_nulls in pylibcudf --- python/cudf/cudf/pylibcudf_tests/conftest.py | 11 ++++++- .../cudf/pylibcudf_tests/test_transform.py | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 python/cudf/cudf/pylibcudf_tests/test_transform.py diff --git a/python/cudf/cudf/pylibcudf_tests/conftest.py b/python/cudf/cudf/pylibcudf_tests/conftest.py index 39832eb4bba..b5bbf470eaf 100644 --- a/python/cudf/cudf/pylibcudf_tests/conftest.py +++ b/python/cudf/cudf/pylibcudf_tests/conftest.py @@ -141,6 +141,15 @@ def sorted_opt(request): return request.param -@pytest.fixture(scope="session", params=[False, True]) +@pytest.fixture( + scope="session", params=[False, True], ids=["without_nulls", "with_nulls"] +) def has_nulls(request): return request.param + + +@pytest.fixture( + scope="session", params=[False, True], ids=["without_nans", "with_nans"] +) +def has_nans(request): + return request.param diff --git a/python/cudf/cudf/pylibcudf_tests/test_transform.py b/python/cudf/cudf/pylibcudf_tests/test_transform.py new file mode 100644 index 00000000000..312939888dd --- /dev/null +++ b/python/cudf/cudf/pylibcudf_tests/test_transform.py @@ -0,0 +1,32 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import math + +import pyarrow as pa +from utils import assert_column_eq + +from cudf._lib import pylibcudf as plc + + +def test_nans_to_nulls(has_nans): + if has_nans: + values = [1, float("nan"), float("nan"), None, 3, None] + else: + values = [1, 4, 5, None, 3, None] + + replaced = [ + None if (v is None or (v is not None and math.isnan(v))) else v + for v in values + ] + + h_input = pa.array(values, type=pa.float32()) + input = plc.interop.from_arrow(h_input) + assert input.null_count() == h_input.null_count + expect = pa.array(replaced, type=pa.float32()) + + mask, null_count = plc.transform.nans_to_nulls(input) + + assert null_count == expect.null_count + got = input.with_mask(mask, null_count) + + assert_column_eq(expect, got)