From 31361242612a2f1198f1defb64cd560ee4eecfa8 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Wed, 24 Mar 2021 18:35:41 -0500 Subject: [PATCH] Fix comparison between Datetime/Timedelta columns and NULL scalars (#7504) Fixes https://github.com/rapidsai/cudf/issues/6897 Authors: - @brandon-b-miller Approvers: - GALI PREM SAGAR (@galipremsagar) - Ram (Ramakrishna Prabhu) (@rgsl888prabhu) URL: https://github.com/rapidsai/cudf/pull/7504 --- python/cudf/cudf/core/column/datetime.py | 2 + python/cudf/cudf/core/column/timedelta.py | 2 + python/cudf/cudf/tests/test_binops.py | 45 +++++++++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index a563248f4ab..0bacbe04356 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -178,6 +178,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: return cudf.Scalar(None, dtype=other.dtype) return cudf.Scalar(other) + elif other is None: + return cudf.Scalar(other, dtype=self.dtype) else: raise TypeError(f"cannot normalize {type(other)}") diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index e22b511db01..a39638106bb 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -275,6 +275,8 @@ def normalize_binop_value(self, other) -> BinaryOperand: return cudf.Scalar(other) elif np.isscalar(other): return cudf.Scalar(other) + elif other is None: + return cudf.Scalar(other, dtype=self.dtype) else: raise TypeError(f"cannot normalize {type(other)}") diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 18f2d7e474b..eb8aaaadd51 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1773,6 +1773,51 @@ def decimal_series(input, dtype): utils.assert_eq(expect, got) +@pytest.mark.parametrize( + "dtype", + [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + "str", + "datetime64[ns]", + "datetime64[us]", + "datetime64[ms]", + "datetime64[s]", + "timedelta64[ns]", + "timedelta64[us]", + "timedelta64[ms]", + "timedelta64[s]", + ], +) +@pytest.mark.parametrize("null_scalar", [None, cudf.NA, np.datetime64("NaT")]) +@pytest.mark.parametrize("cmpop", _cmpops) +def test_column_null_scalar_comparison(dtype, null_scalar, cmpop): + # This test is meant to validate that comparing + # a series of any dtype with a null scalar produces + # a new series where all the elements are . + + if isinstance(null_scalar, np.datetime64): + if np.dtype(dtype).kind not in "mM": + pytest.skip() + null_scalar = null_scalar.astype(dtype) + + dtype = np.dtype(dtype) + + data = [1, 2, 3, 4, 5] + sr = cudf.Series(data, dtype=dtype) + result = cmpop(sr, null_scalar) + + assert result.isnull().all() + + @pytest.mark.parametrize("fn", ["eq", "ne", "lt", "gt", "le", "ge"]) def test_equality_ops_index_mismatch(fn): a = cudf.Series(