From 6f3f50721e585b55bd263cc36926eb2c99c5f811 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Wed, 10 May 2023 15:19:07 -0400 Subject: [PATCH] Support the case=False argument to str.contains (#13290) Closes https://github.com/rapidsai/cudf/issues/13253 Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/13290 --- python/cudf/cudf/core/column/string.py | 35 +++++++++++++++----------- python/cudf/cudf/tests/test_string.py | 11 ++++++++ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index fefa7beb562..1a09fc0b985 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -744,7 +744,7 @@ def contains( 4 False dtype: bool - The ``pat`` may also be a list of strings in which case + The ``pat`` may also be a sequence of strings in which case the individual strings are searched in corresponding rows. >>> s2 = cudf.Series(['house', 'dog', 'and', '', '']) @@ -756,8 +756,6 @@ def contains( 4 dtype: bool """ # noqa W605 - if case is not True: - raise NotImplementedError("`case` parameter is not yet supported") if na is not np.nan: raise NotImplementedError("`na` parameter is not yet supported") if regex and isinstance(pat, re.Pattern): @@ -767,22 +765,31 @@ def contains( raise NotImplementedError( "unsupported value for `flags` parameter" ) - - if pat is None: - result_col = column.column_empty( - len(self._column), dtype="bool", masked=True + if regex and not case: + raise NotImplementedError( + "`case=False` only supported when `regex=False`" ) - elif is_scalar(pat): + + if is_scalar(pat): if regex: result_col = libstrings.contains_re(self._column, pat, flags) else: - result_col = libstrings.contains( - self._column, cudf.Scalar(pat, "str") - ) + if case is False: + input_column = libstrings.to_lower(self._column) + pat = cudf.Scalar(pat.lower(), dtype="str") # type: ignore + else: + input_column = self._column + pat = cudf.Scalar(pat, dtype="str") # type: ignore + result_col = libstrings.contains(input_column, pat) else: - result_col = libstrings.contains_multiple( - self._column, column.as_column(pat, dtype="str") - ) + # TODO: we silently ignore the `regex=` flag here + if case is False: + input_column = libstrings.to_lower(self._column) + pat = libstrings.to_lower(column.as_column(pat, dtype="str")) + else: + input_column = self._column + pat = column.as_column(pat, dtype="str") + result_col = libstrings.contains_multiple(input_column, pat) return self._return_or_inplace(result_col) def like(self, pat: str, esc: str = None) -> SeriesOrIndex: diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index c866e064366..12e832ba23b 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -820,6 +820,17 @@ def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): assert_eq(expect, got) +def test_string_contains_case(ps_gs): + ps, gs = ps_gs + with pytest.raises(NotImplementedError): + gs.str.contains("A", case=False) + expected = ps.str.contains("A", regex=False, case=False) + got = gs.str.contains("A", regex=False, case=False) + assert_eq(expected, got) + got = gs.str.contains("a", regex=False, case=False) + assert_eq(expected, got) + + @pytest.mark.parametrize( "pat,esc,expect", [