From 8d80d5c960d3f396ac2a7841d6d8a07d9528f175 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Tue, 19 Jan 2021 13:37:25 -0600 Subject: [PATCH] Fix comparisons between Series and cudf.NA(#7072) Fixes https://github.com/rapidsai/cudf/issues/7043, gives less than ideal results due to https://github.com/rapidsai/cudf/issues/7066. Authors: - brandon-b-miller Approvers: - GALI PREM SAGAR (@galipremsagar) URL: https://github.com/rapidsai/cudf/pull/7072 --- python/cudf/cudf/core/column/string.py | 2 +- python/cudf/cudf/core/series.py | 2 ++ python/cudf/cudf/tests/test_binops.py | 40 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index ea14f23ea44..f5df440b865 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -4966,7 +4966,7 @@ def binary_operator(self, op, rhs, reflect=False): lhs = self if reflect: lhs, rhs = rhs, lhs - if isinstance(rhs, (StringColumn, str)): + if isinstance(rhs, (StringColumn, str, cudf.Scalar)): if op == "add": return lhs.str().cat(others=rhs) elif op in ("eq", "ne", "gt", "lt", "ge", "le"): diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 9602cf8d473..2b9078abed6 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -1569,6 +1569,8 @@ def _normalize_binop_value(self, other): return other._column elif isinstance(other, Index): return Series(other)._column + elif other is cudf.NA: + return cudf.Scalar(other, dtype=self.dtype) else: return self._column.normalize_binop_value(other) diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index be1cef50ec3..93bc6d1c573 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -19,6 +19,7 @@ DATETIME_TYPES, FLOAT_TYPES, INTEGER_TYPES, + NUMERIC_TYPES, TIMEDELTA_TYPES, ) @@ -1527,3 +1528,42 @@ def test_binops_with_lhs_numpy_scalar(frame, dtype): expected = pd.Index(expected) utils.assert_eq(expected, got) + + +@pytest.mark.parametrize( + "dtype", + [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "datetime64[ns]", + "datetime64[us]", + "datetime64[ms]", + "datetime64[s]", + "timedelta64[ns]", + "timedelta64[us]", + "timedelta64[ms]", + "timedelta64[s]", + ], +) +@pytest.mark.parametrize("op", _operators_comparison) +def test_binops_with_NA_consistent(dtype, op): + data = [1, 2, 3] + sr = cudf.Series(data, dtype=dtype) + + result = getattr(sr, op)(cudf.NA) + if dtype in NUMERIC_TYPES: + if op == "ne": + expect_all = True + else: + expect_all = False + assert (result == expect_all).all() + elif dtype in DATETIME_TYPES & TIMEDELTA_TYPES: + assert result._column.null_count == len(data)