Skip to content

Commit

Permalink
Fix comparisons between Series and cudf.NA (#7072)
Browse files Browse the repository at this point in the history
Fixes #7043. The results are still less than ideal because of #7066.

Authors:
  - brandon-b-miller <[email protected]>

Approvers:
  - GALI PREM SAGAR (@galipremsagar)

URL: #7072
  • Loading branch information
brandon-b-miller authored Jan 19, 2021
1 parent e8ecb24 commit 8d80d5c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -4966,7 +4966,7 @@ def binary_operator(self, op, rhs, reflect=False):
lhs = self
if reflect:
lhs, rhs = rhs, lhs
if isinstance(rhs, (StringColumn, str)):
if isinstance(rhs, (StringColumn, str, cudf.Scalar)):
if op == "add":
return lhs.str().cat(others=rhs)
elif op in ("eq", "ne", "gt", "lt", "ge", "le"):
Expand Down
2 changes: 2 additions & 0 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,8 @@ def _normalize_binop_value(self, other):
return other._column
elif isinstance(other, Index):
return Series(other)._column
elif other is cudf.NA:
return cudf.Scalar(other, dtype=self.dtype)
else:
return self._column.normalize_binop_value(other)

Expand Down
40 changes: 40 additions & 0 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DATETIME_TYPES,
FLOAT_TYPES,
INTEGER_TYPES,
NUMERIC_TYPES,
TIMEDELTA_TYPES,
)

Expand Down Expand Up @@ -1527,3 +1528,42 @@ def test_binops_with_lhs_numpy_scalar(frame, dtype):
expected = pd.Index(expected)

utils.assert_eq(expected, got)


@pytest.mark.parametrize(
    "dtype",
    [
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "float32",
        "float64",
        "datetime64[ns]",
        "datetime64[us]",
        "datetime64[ms]",
        "datetime64[s]",
        "timedelta64[ns]",
        "timedelta64[us]",
        "timedelta64[ms]",
        "timedelta64[s]",
    ],
)
@pytest.mark.parametrize("op", _operators_comparison)
def test_binops_with_NA_consistent(dtype, op):
    """Check that comparing a Series against ``cudf.NA`` is consistent.

    For every parametrized dtype and every comparison operator, a
    three-element Series is compared with ``cudf.NA``:

    * numeric dtypes: every element of the result must equal ``True``
      for ``"ne"`` and ``False`` for every other operator;
    * datetime / timedelta dtypes: every element of the result must be
      null.
    """
    data = [1, 2, 3]
    sr = cudf.Series(data, dtype=dtype)

    result = getattr(sr, op)(cudf.NA)
    if dtype in NUMERIC_TYPES:
        # Only "not equal" is expected to hold against NA.
        expect_all = op == "ne"
        assert (result == expect_all).all()
    # The dtype must belong to EITHER collection. The previous check,
    # ``DATETIME_TYPES & TIMEDELTA_TYPES``, tested membership in their
    # intersection, which no dtype name can satisfy, so this assertion
    # was never executed. Two separate ``in`` checks also work whether
    # the collections are sets or lists.
    elif dtype in DATETIME_TYPES or dtype in TIMEDELTA_TYPES:
        assert result._column.null_count == len(data)

0 comments on commit 8d80d5c

Please sign in to comment.