Skip to content

Commit

Permalink
Fix comparison between Datetime/Timedelta columns and NULL scalars (#…
Browse files Browse the repository at this point in the history
…7504)

Fixes #6897

Authors:
  - @brandon-b-miller

Approvers:
  - GALI PREM SAGAR (@galipremsagar)
  - Ram (Ramakrishna Prabhu) (@rgsl888prabhu)

URL: #7504
  • Loading branch information
brandon-b-miller authored Mar 24, 2021
1 parent f38daf3 commit 3136124
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
return cudf.Scalar(None, dtype=other.dtype)

return cudf.Scalar(other)
elif other is None:
return cudf.Scalar(other, dtype=self.dtype)
else:
raise TypeError(f"cannot normalize {type(other)}")

Expand Down
2 changes: 2 additions & 0 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def normalize_binop_value(self, other) -> BinaryOperand:
return cudf.Scalar(other)
elif np.isscalar(other):
return cudf.Scalar(other)
elif other is None:
return cudf.Scalar(other, dtype=self.dtype)
else:
raise TypeError(f"cannot normalize {type(other)}")

Expand Down
45 changes: 45 additions & 0 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,51 @@ def decimal_series(input, dtype):
utils.assert_eq(expect, got)


@pytest.mark.parametrize(
"dtype",
[
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float32",
"float64",
"str",
"datetime64[ns]",
"datetime64[us]",
"datetime64[ms]",
"datetime64[s]",
"timedelta64[ns]",
"timedelta64[us]",
"timedelta64[ms]",
"timedelta64[s]",
],
)
@pytest.mark.parametrize("null_scalar", [None, cudf.NA, np.datetime64("NaT")])
@pytest.mark.parametrize("cmpop", _cmpops)
def test_column_null_scalar_comparison(dtype, null_scalar, cmpop):
# This test is meant to validate that comparing
# a series of any dtype with a null scalar produces
# a new series where all the elements are <NA>.

if isinstance(null_scalar, np.datetime64):
if np.dtype(dtype).kind not in "mM":
pytest.skip()
null_scalar = null_scalar.astype(dtype)

dtype = np.dtype(dtype)

data = [1, 2, 3, 4, 5]
sr = cudf.Series(data, dtype=dtype)
result = cmpop(sr, null_scalar)

assert result.isnull().all()


@pytest.mark.parametrize("fn", ["eq", "ne", "lt", "gt", "le", "ge"])
def test_equality_ops_index_mismatch(fn):
a = cudf.Series(
Expand Down

0 comments on commit 3136124

Please sign in to comment.