Skip to content

Commit

Permalink
Fix Series comparison vs scalars (#12519)
Browse files Browse the repository at this point in the history
Fixes an issue where this happens:

```python
import cudf
cudf.Series(['a','b','c']) == 1
```
```
  File "/raid/brmiller/anaconda/envs/cudf_dev/lib/python3.9/site-packages/cudf/core/mixins/mixin_factory.py", line 11, in wrapper
    return method(self, *args1, *args2, **kwargs1, **kwargs2)
  File "/raid/brmiller/anaconda/envs/cudf_dev/lib/python3.9/site-packages/cudf/core/indexed_frame.py", line 3278, in _binaryop
    ColumnAccessor(type(self)._colwise_binop(operands, op)),
  File "/raid/brmiller/anaconda/envs/cudf_dev/lib/python3.9/site-packages/cudf/core/column_accessor.py", line 124, in __init__
    column_length = len(data[next(iter(data))])
TypeError: object of type 'bool' has no len()
```

It turns out this happens because `StringColumn`'s `normalize_binop_value` method returns `NotImplemented` for scalars that are not of dtype `object`. This eventually causes Python to dispatch to the Python scalar class's `__eq__`, which returns the scalar `False` when it encounters a cuDF object. cuDF expects a column object at this point but has a scalar instead.

This in turn causes cuDF to try to construct a `ColumnAccessor` around a dict that looks like `{'name': False}`, ultimately throwing the error.

This PR proposes to short-circuit this behavior according to the rules for comparing Python string scalars with other objects:
- Always return `False` for `__eq__`, even if the character in the string is equivalent to whatever is being compared
- Always return `True` for `__ne__`, under the same conditions as above
- Copy the input mask

This should align us with pandas behavior for this case:

```python
>>> pd.Series(['a','b', 'c'], dtype='string') == 1
0    False
1    False
2    False
dtype: boolean
>>> pd.Series(['a','b', 'c'], dtype='string') != 1
0    True
1    True
2    True
dtype: boolean
```

EDIT:
Updating this PR to handle a similar issue resulting in the same error when comparing datetime series to strings that contain valid datetimes, such as `20110101`.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #12519
  • Loading branch information
brandon-b-miller authored Feb 10, 2023
1 parent 048f936 commit c4a1389
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 23 deletions.
7 changes: 6 additions & 1 deletion python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.

from __future__ import annotations

Expand Down Expand Up @@ -261,6 +261,11 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
return cudf.Scalar(None, dtype=other.dtype)

return cudf.Scalar(other)
elif isinstance(other, str):
try:
return cudf.Scalar(other, dtype=self.dtype)
except ValueError:
pass

return NotImplemented

Expand Down
13 changes: 12 additions & 1 deletion python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5665,7 +5665,7 @@ def normalize_binop_value(
and other.dtype == "object"
):
return other
if isinstance(other, str):
if is_scalar(other):
return cudf.Scalar(other)
return NotImplemented

Expand Down Expand Up @@ -5701,6 +5701,17 @@ def _binaryop(
return NotImplemented

if isinstance(other, (StringColumn, str, cudf.Scalar)):
if isinstance(other, cudf.Scalar) and other.dtype != "O":
if op in {
"__eq__",
"__ne__",
}:
return column.full(
len(self), op == "__ne__", dtype="bool"
).set_mask(self.mask)
else:
return NotImplemented

if op == "__add__":
if isinstance(other, cudf.Scalar):
other = cast(
Expand Down
81 changes: 60 additions & 21 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
# Copyright (c) 2018-2023, NVIDIA CORPORATION.

import decimal
import operator
Expand Down Expand Up @@ -320,29 +320,68 @@ def test_series_compare_nulls(cmpop, dtypes):
utils.assert_eq(expect, got)


@pytest.mark.parametrize(
"obj", [pd.Series(["a", "b", None, "d", "e", None], dtype="string"), "a"]
)
@pytest.mark.parametrize("cmpop", _cmpops)
@pytest.mark.parametrize(
"cmp_obj",
[pd.Series(["b", "a", None, "d", "f", None], dtype="string"), "a"],
)
def test_string_series_compare(obj, cmpop, cmp_obj):
@pytest.fixture
def str_series_cmp_data():
    """Return a pandas ``string``-dtype Series containing some null entries."""
    values = ["a", "b", None, "d", "e", None]
    return pd.Series(values, dtype="string")

g_obj = obj
if isinstance(g_obj, pd.Series):
g_obj = Series.from_pandas(g_obj)
g_cmp_obj = cmp_obj
if isinstance(g_cmp_obj, pd.Series):
g_cmp_obj = Series.from_pandas(g_cmp_obj)
got = cmpop(g_obj, g_cmp_obj)
expected = cmpop(obj, cmp_obj)

if isinstance(expected, pd.Series):
expected = cudf.from_pandas(expected)
@pytest.fixture(
    params=_cmpops,
    ids=[cmpop.__name__ for cmpop in _cmpops],
)
def str_series_compare_str_cmpop(request):
    """Return the comparison operator selected for the current run.

    The fixture is parametrized over every entry of ``_cmpops`` and uses
    each operator's ``__name__`` as its test id.
    """
    selected_op = request.param
    return selected_op

utils.assert_eq(expected, got)

@pytest.fixture(params=[operator.eq, operator.ne], ids=["eq", "ne"])
def str_series_compare_num_cmpop(request):
    """Return ``operator.eq`` or ``operator.ne`` for the current run."""
    selected_op = request.param
    return selected_op


@pytest.fixture(params=[1, 1.5, True], ids=["int", "float", "bool"])
def cmp_scalar(request):
    """Return the non-string scalar (int, float or bool) for the current run."""
    scalar = request.param
    return scalar


def test_str_series_compare_str(
    str_series_cmp_data, str_series_compare_str_cmpop
):
    """Compare a string series against the scalar ``"a"`` (series on the left).

    The pandas result is the reference; the cuDF result is converted with
    ``to_pandas(nullable=True)`` before the two are compared.
    """
    compare = str_series_compare_str_cmpop

    expected = compare(str_series_cmp_data, "a")

    gpu_series = Series.from_pandas(str_series_cmp_data)
    actual = compare(gpu_series, "a")

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


def test_str_series_compare_str_reflected(
    str_series_cmp_data, str_series_compare_str_cmpop
):
    """Compare the scalar ``"a"`` against a string series (scalar on the left).

    The pandas result is the reference; the cuDF result is converted with
    ``to_pandas(nullable=True)`` before the two are compared.
    """
    compare = str_series_compare_str_cmpop

    expected = compare("a", str_series_cmp_data)

    gpu_series = Series.from_pandas(str_series_cmp_data)
    actual = compare("a", gpu_series)

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


def test_str_series_compare_num(
    str_series_cmp_data, str_series_compare_num_cmpop, cmp_scalar
):
    """Apply ``eq``/``ne`` between a string series and a non-string scalar.

    The series is the left operand. The pandas result is the reference; the
    cuDF result is converted with ``to_pandas(nullable=True)`` before the
    two are compared.
    """
    compare = str_series_compare_num_cmpop

    expected = compare(str_series_cmp_data, cmp_scalar)

    gpu_series = Series.from_pandas(str_series_cmp_data)
    actual = compare(gpu_series, cmp_scalar)

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


def test_str_series_compare_num_reflected(
    str_series_cmp_data, str_series_compare_num_cmpop, cmp_scalar
):
    """Apply ``eq``/``ne`` between a non-string scalar and a string series.

    The scalar is the left operand. The pandas result is the reference; the
    cuDF result is converted with ``to_pandas(nullable=True)`` before the
    two are compared.
    """
    compare = str_series_compare_num_cmpop

    expected = compare(cmp_scalar, str_series_cmp_data)

    gpu_series = Series.from_pandas(str_series_cmp_data)
    actual = compare(cmp_scalar, gpu_series)

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


@pytest.mark.parametrize("obj_class", ["Series", "Index"])
Expand Down
26 changes: 26 additions & 0 deletions python/cudf/cudf/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@
expect_warning_if,
)

_cmpops = [
operator.lt,
operator.gt,
operator.le,
operator.ge,
operator.eq,
operator.ne,
]


def data1():
return pd.date_range("20010101", "20020215", freq="400h", name="times")
Expand Down Expand Up @@ -986,6 +995,23 @@ def test_datetime_series_ops_with_scalars(data, other_scalars, dtype, op):
)


@pytest.mark.parametrize("data", ["20110101", "20120101", "20130101"])
@pytest.mark.parametrize("other_scalars", ["20110101", "20120101", "20130101"])
@pytest.mark.parametrize("op", _cmpops)
@pytest.mark.parametrize(
    "dtype",
    ["datetime64[ns]", "datetime64[us]", "datetime64[ms]", "datetime64[s]"],
)
def test_datetime_series_cmpops_with_scalars(data, other_scalars, dtype, op):
    """Compare a datetime series against a date-like string scalar.

    A cuDF series is built from ``data`` with the given datetime ``dtype``
    and converted to pandas; ``op`` is applied to each with
    ``other_scalars`` and the results are checked with ``assert_eq``.
    """
    gpu_series = cudf.Series(data=data, dtype=dtype)
    pandas_series = gpu_series.to_pandas()

    expected = op(pandas_series, other_scalars)
    actual = op(gpu_series, other_scalars)

    assert_eq(expected, actual)


@pytest.mark.parametrize(
"data",
[
Expand Down

0 comments on commit c4a1389

Please sign in to comment.