Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Series comparison vs scalars #12519

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2023, NVIDIA CORPORATION.

from __future__ import annotations

Expand Down Expand Up @@ -261,6 +261,11 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
return cudf.Scalar(None, dtype=other.dtype)

return cudf.Scalar(other)
elif isinstance(other, str):
try:
return cudf.Scalar(other, dtype=self.dtype)
except ValueError:
pass

return NotImplemented

Expand Down
13 changes: 12 additions & 1 deletion python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5665,7 +5665,7 @@ def normalize_binop_value(
and other.dtype == "object"
):
return other
if isinstance(other, str):
if is_scalar(other):
return cudf.Scalar(other)
return NotImplemented

Expand Down Expand Up @@ -5701,6 +5701,17 @@ def _binaryop(
return NotImplemented

if isinstance(other, (StringColumn, str, cudf.Scalar)):
if isinstance(other, cudf.Scalar) and other.dtype != "O":
if op in {
"__eq__",
"__ne__",
}:
return column.full(
len(self), op == "__ne__", dtype="bool"
).set_mask(self.mask)
else:
return NotImplemented

if op == "__add__":
if isinstance(other, cudf.Scalar):
other = cast(
Expand Down
81 changes: 60 additions & 21 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
# Copyright (c) 2018-2023, NVIDIA CORPORATION.

import decimal
import operator
Expand Down Expand Up @@ -320,29 +320,68 @@ def test_series_compare_nulls(cmpop, dtypes):
utils.assert_eq(expect, got)


@pytest.mark.parametrize(
"obj", [pd.Series(["a", "b", None, "d", "e", None], dtype="string"), "a"]
)
@pytest.mark.parametrize("cmpop", _cmpops)
@pytest.mark.parametrize(
"cmp_obj",
[pd.Series(["b", "a", None, "d", "f", None], dtype="string"), "a"],
)
def test_string_series_compare(obj, cmpop, cmp_obj):
@pytest.fixture
def str_series_cmp_data():
    """Return the pandas ``string``-dtype Series used by the comparison tests.

    The data holds four letters and two ``None`` entries so that null
    handling is exercised by every comparison.
    """
    values = ["a", "b", None, "d", "e", None]
    return pd.Series(values, dtype="string")

g_obj = obj
if isinstance(g_obj, pd.Series):
g_obj = Series.from_pandas(g_obj)
g_cmp_obj = cmp_obj
if isinstance(g_cmp_obj, pd.Series):
g_cmp_obj = Series.from_pandas(g_cmp_obj)
got = cmpop(g_obj, g_cmp_obj)
expected = cmpop(obj, cmp_obj)

if isinstance(expected, pd.Series):
expected = cudf.from_pandas(expected)
@pytest.fixture(params=_cmpops, ids=[cmp.__name__ for cmp in _cmpops])
def str_series_compare_str_cmpop(request):
    """Parametrized fixture supplying each operator in ``_cmpops`` in turn.

    Test ids are taken from each operator's ``__name__``.
    """
    return request.param

utils.assert_eq(expected, got)

@pytest.fixture(params=[operator.eq, operator.ne], ids=["eq", "ne"])
def str_series_compare_num_cmpop(request):
    """Parametrized fixture supplying ``operator.eq`` then ``operator.ne``."""
    return request.param


@pytest.fixture(params=[1, 1.5, True], ids=["int", "float", "bool"])
def cmp_scalar(request):
    """Parametrized fixture supplying an int, a float and a bool scalar."""
    return request.param


def test_str_series_compare_str(
    str_series_cmp_data, str_series_compare_str_cmpop
):
    """Check ``series <op> "a"`` gives the same result as pandas.

    The operator is applied to the pandas Series and to its
    ``Series.from_pandas`` counterpart; the latter result is converted
    back with ``to_pandas(nullable=True)`` before the comparison.
    """
    compare = str_series_compare_str_cmpop

    expected = compare(str_series_cmp_data, "a")
    converted = Series.from_pandas(str_series_cmp_data)
    actual = compare(converted, "a")

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


def test_str_series_compare_str_reflected(
    str_series_cmp_data, str_series_compare_str_cmpop
):
    """Check ``"a" <op> series`` gives the same result as pandas.

    Same as the non-reflected test, but with the string scalar as the
    left-hand operand so the reflected comparison path is exercised.
    """
    compare = str_series_compare_str_cmpop

    expected = compare("a", str_series_cmp_data)
    converted = Series.from_pandas(str_series_cmp_data)
    actual = compare("a", converted)

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


def test_str_series_compare_num(
    str_series_cmp_data, str_series_compare_num_cmpop, cmp_scalar
):
    """Check ``series <op> scalar`` for a non-string scalar matches pandas.

    The operator is applied to the pandas Series and to its
    ``Series.from_pandas`` counterpart; the latter result is converted
    back with ``to_pandas(nullable=True)`` before the comparison.
    """
    compare = str_series_compare_num_cmpop

    expected = compare(str_series_cmp_data, cmp_scalar)
    converted = Series.from_pandas(str_series_cmp_data)
    actual = compare(converted, cmp_scalar)

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


def test_str_series_compare_num_reflected(
    str_series_cmp_data, str_series_compare_num_cmpop, cmp_scalar
):
    """Check ``scalar <op> series`` for a non-string scalar matches pandas.

    Same as the non-reflected test, but with the scalar as the left-hand
    operand so the reflected comparison path is exercised.
    """
    compare = str_series_compare_num_cmpop

    expected = compare(cmp_scalar, str_series_cmp_data)
    converted = Series.from_pandas(str_series_cmp_data)
    actual = compare(cmp_scalar, converted)

    utils.assert_eq(expected, actual.to_pandas(nullable=True))


@pytest.mark.parametrize("obj_class", ["Series", "Index"])
Expand Down
26 changes: 26 additions & 0 deletions python/cudf/cudf/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@
expect_warning_if,
)

# The six comparison operators from the ``operator`` module, used to
# parametrize comparison tests in this file.
_cmpops = [
    operator.lt,
    operator.gt,
    operator.le,
    operator.ge,
    operator.eq,
    operator.ne,
]


def data1():
return pd.date_range("20010101", "20020215", freq="400h", name="times")
Expand Down Expand Up @@ -986,6 +995,23 @@ def test_datetime_series_ops_with_scalars(data, other_scalars, dtype, op):
)


@pytest.mark.parametrize("data", ["20110101", "20120101", "20130101"])
@pytest.mark.parametrize("other_scalars", ["20110101", "20120101", "20130101"])
@pytest.mark.parametrize("op", _cmpops)
@pytest.mark.parametrize(
    "dtype",
    ["datetime64[ns]", "datetime64[us]", "datetime64[ms]", "datetime64[s]"],
)
def test_datetime_series_cmpops_with_scalars(data, other_scalars, dtype, op):
    """Check a datetime Series compared with a date string matches pandas.

    A ``cudf.Series`` is built from ``data`` with the given datetime
    ``dtype``; the pandas reference is derived from it via ``to_pandas()``
    so both sides start from identical values. ``op`` is then applied to
    each with the string ``other_scalars`` as the right-hand operand.
    """
    gpu_series = cudf.Series(data=data, dtype=dtype)
    pandas_series = gpu_series.to_pandas()

    expected = op(pandas_series, other_scalars)
    actual = op(gpu_series, other_scalars)

    assert_eq(expected, actual)


@pytest.mark.parametrize(
"data",
[
Expand Down