Skip to content

Commit

Permalink
Support the case=False argument to str.contains (#13290)
Browse files Browse the repository at this point in the history
Closes #13253

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #13290
  • Loading branch information
shwina authored May 10, 2023
1 parent 3d814bd commit 6f3f507
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
35 changes: 21 additions & 14 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def contains(
4 False
dtype: bool
The ``pat`` may also be a list of strings in which case
The ``pat`` may also be a sequence of strings in which case
the individual strings are searched in corresponding rows.
>>> s2 = cudf.Series(['house', 'dog', 'and', '', ''])
Expand All @@ -756,8 +756,6 @@ def contains(
4 <NA>
dtype: bool
""" # noqa W605
if case is not True:
raise NotImplementedError("`case` parameter is not yet supported")
if na is not np.nan:
raise NotImplementedError("`na` parameter is not yet supported")
if regex and isinstance(pat, re.Pattern):
Expand All @@ -767,22 +765,31 @@ def contains(
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

if pat is None:
result_col = column.column_empty(
len(self._column), dtype="bool", masked=True
if regex and not case:
raise NotImplementedError(
"`case=False` only supported when `regex=False`"
)
elif is_scalar(pat):

if is_scalar(pat):
if regex:
result_col = libstrings.contains_re(self._column, pat, flags)
else:
result_col = libstrings.contains(
self._column, cudf.Scalar(pat, "str")
)
if case is False:
input_column = libstrings.to_lower(self._column)
pat = cudf.Scalar(pat.lower(), dtype="str") # type: ignore
else:
input_column = self._column
pat = cudf.Scalar(pat, dtype="str") # type: ignore
result_col = libstrings.contains(input_column, pat)
else:
result_col = libstrings.contains_multiple(
self._column, column.as_column(pat, dtype="str")
)
# TODO: we silently ignore the `regex=` flag here
if case is False:
input_column = libstrings.to_lower(self._column)
pat = libstrings.to_lower(column.as_column(pat, dtype="str"))
else:
input_column = self._column
pat = column.as_column(pat, dtype="str")
result_col = libstrings.contains_multiple(input_column, pat)
return self._return_or_inplace(result_col)

def like(self, pat: str, esc: str = None) -> SeriesOrIndex:
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,17 @@ def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise):
assert_eq(expect, got)


def test_string_contains_case(ps_gs):
    """Check ``str.contains(..., case=False)`` on cudf against pandas.

    Case-insensitive matching is exercised only with ``regex=False``;
    calling it without ``regex=False`` is expected to raise
    ``NotImplementedError``.
    """
    pandas_series, cudf_series = ps_gs

    # case=False without regex=False must be rejected.
    with pytest.raises(NotImplementedError):
        cudf_series.str.contains("A", case=False)

    expected = pandas_series.str.contains("A", regex=False, case=False)

    # The case of the pattern itself must not matter: both spellings are
    # compared against the same pandas result.
    for literal in ("A", "a"):
        got = cudf_series.str.contains(literal, regex=False, case=False)
        assert_eq(expected, got)


@pytest.mark.parametrize(
"pat,esc,expect",
[
Expand Down

0 comments on commit 6f3f507

Please sign in to comment.