diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index be1cef50ec3..93bc6d1c573 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -19,6 +19,7 @@ DATETIME_TYPES, FLOAT_TYPES, INTEGER_TYPES, + NUMERIC_TYPES, TIMEDELTA_TYPES, ) @@ -1527,3 +1528,42 @@ def test_binops_with_lhs_numpy_scalar(frame, dtype): expected = pd.Index(expected) utils.assert_eq(expected, got) + + +@pytest.mark.parametrize( + "dtype", + [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "datetime64[ns]", + "datetime64[us]", + "datetime64[ms]", + "datetime64[s]", + "timedelta64[ns]", + "timedelta64[us]", + "timedelta64[ms]", + "timedelta64[s]", + ], +) +@pytest.mark.parametrize("op", _operators_comparison) +def test_binops_with_NA_consistent(dtype, op): + data = [1, 2, 3] + sr = cudf.Series(data, dtype=dtype) + + result = getattr(sr, op)(cudf.NA) + if dtype in NUMERIC_TYPES: + if op == "ne": + expect_all = True + else: + expect_all = False + assert (result == expect_all).all() + elif dtype in DATETIME_TYPES & TIMEDELTA_TYPES: + assert result._column.null_count == len(data)