diff --git a/cpp/include/cudf/strings/findall.hpp b/cpp/include/cudf/strings/findall.hpp index 4207cddbafb..1cb742ec09e 100644 --- a/cpp/include/cudf/strings/findall.hpp +++ b/cpp/include/cudf/strings/findall.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include @@ -47,14 +48,16 @@ namespace strings { * * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. * - * @param strings Strings instance for this operation. + * @param input Strings instance for this operation. * @param pattern Regex pattern to match within each string. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned table's device memory. * @return New table of strings columns. */ std::unique_ptr findall( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -77,14 +80,16 @@ std::unique_ptr
findall( * * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. * - * @param strings Strings instance for this operation. + * @param input Strings instance for this operation. * @param pattern Regex pattern to match within each string. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New lists column of strings. */ std::unique_ptr findall_record( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of doxygen group diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 8fb754848d4..810e44cc27d 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -110,17 +110,18 @@ struct findall_count_fn : public findall_fn { // std::unique_ptr
findall( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto const strings_count = strings.size(); - auto const d_strings = column_device_view::create(strings.parent(), stream); + auto const strings_count = input.size(); + auto const d_strings = column_device_view::create(input.parent(), stream); - auto const d_flags = detail::get_character_flags_table(); // compile regex into device object - auto const d_prog = reprog_device::create(pattern, d_flags, strings_count, stream); + auto const d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); auto const regex_insts = d_prog->insts_counts(); rmm::device_uvector find_counts(strings_count, stream); @@ -205,12 +206,13 @@ std::unique_ptr
findall( // external API -std::unique_ptr
findall(strings_column_view const& strings, +std::unique_ptr
findall(strings_column_view const& input, std::string const& pattern, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::findall(strings, pattern, rmm::cuda_stream_default, mr); + return detail::findall(input, pattern, flags, rmm::cuda_stream_default, mr); } } // namespace strings diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index 9ffdb33f5f2..c93eb0c17db 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -79,17 +79,18 @@ struct findall_fn { // std::unique_ptr findall_record( - strings_column_view const& strings, + strings_column_view const& input, std::string const& pattern, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - auto const strings_count = strings.size(); - auto const d_strings = column_device_view::create(strings.parent(), stream); + auto const strings_count = input.size(); + auto const d_strings = column_device_view::create(input.parent(), stream); // compile regex into device object auto const d_prog = - reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); // Create lists offsets column auto offsets = count_matches(*d_strings, *d_prog, stream, mr); @@ -159,12 +160,13 @@ std::unique_ptr findall_record( // external API -std::unique_ptr findall_record(strings_column_view const& strings, +std::unique_ptr findall_record(strings_column_view const& input, std::string const& pattern, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::findall_record(strings, pattern, rmm::cuda_stream_default, mr); + return detail::findall_record(input, pattern, flags, rmm::cuda_stream_default, mr); } } // namespace strings diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index 4b1305a870a..a4a28f31ce2 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -97,6 +97,51 @@ TEST_F(StringsFindallTests, FindallRecord) CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); } +TEST_F(StringsFindallTests, Multiline) +{ + cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"}); + auto view = cudf::strings_column_view(input); + + { + auto results = cudf::strings::findall(view, "(^abc$)", cudf::strings::regex_flags::MULTILINE); + auto col0 = + cudf::test::strings_column_wrapper({"abc", "abc", "abc", "", "abc"}, {1, 1, 1, 0, 1}); + auto col1 = cudf::test::strings_column_wrapper({"abc", "", "", "", ""}, {1, 0, 0, 0, 0}); + auto expected = cudf::table_view({col0, col1}); + CUDF_TEST_EXPECT_TABLES_EQUAL(results->view(), expected); + } + { + auto results = + cudf::strings::findall_record(view, "(^abc$)", cudf::strings::regex_flags::MULTILINE); + bool valids[] = {1, 1, 1, 0, 1}; + using LCW = cudf::test::lists_column_wrapper; + LCW expected({LCW{"abc", "abc"}, LCW{"abc"}, LCW{"abc"}, LCW{}, LCW{"abc"}}, valids); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); + } +} + +TEST_F(StringsFindallTests, DotAll) +{ + cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""}); + auto view = cudf::strings_column_view(input); + + { + auto results = cudf::strings::findall(view, "(b.*f)", cudf::strings::regex_flags::DOTALL); + auto col0 = + cudf::test::strings_column_wrapper({"bc\nfa\nef", "bbc\nfff", "bcdef", ""}, {1, 1, 1, 0}); + auto expected = cudf::table_view({col0}); + CUDF_TEST_EXPECT_TABLES_EQUAL(results->view(), expected); + } + { + auto results = + cudf::strings::findall_record(view, "(b.*f)", cudf::strings::regex_flags::DOTALL); + bool valids[] = {1, 1, 1, 0}; + using LCW = cudf::test::lists_column_wrapper; + LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdef"}, LCW{}}, valids); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); + } +} + TEST_F(StringsFindallTests, MediumRegex) { // This results in 15 regex instructions and falls in the 'medium' range. diff --git a/python/cudf/cudf/_lib/cpp/strings/findall.pxd b/python/cudf/cudf/_lib/cpp/strings/findall.pxd index 5533467d72a..5edb792831b 100644 --- a/python/cudf/cudf/_lib/cpp/strings/findall.pxd +++ b/python/cudf/cudf/_lib/cpp/strings/findall.pxd @@ -5,6 +5,7 @@ from libcpp.string cimport string from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view +from cudf._lib.cpp.strings.contains cimport regex_flags from cudf._lib.cpp.table.table cimport table @@ -12,8 +13,10 @@ cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[table] findall( const column_view& source_strings, - const string& pattern) except + + const string& pattern, + regex_flags flags) except + cdef unique_ptr[column] findall_record( const column_view& source_strings, - const string& pattern) except + + const string& pattern, + regex_flags flags) except + diff --git a/python/cudf/cudf/_lib/strings/findall.pyx b/python/cudf/cudf/_lib/strings/findall.pyx index b17988018a6..c4e4b6c38d8 100644 --- a/python/cudf/cudf/_lib/strings/findall.pyx +++ b/python/cudf/cudf/_lib/strings/findall.pyx @@ -1,5 +1,6 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. +from libc.stdint cimport uint32_t from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move @@ -8,6 +9,7 @@ from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.scalar.scalar cimport string_scalar +from cudf._lib.cpp.strings.contains cimport regex_flags from cudf._lib.cpp.strings.findall cimport ( findall as cpp_findall, findall_record as cpp_findall_record, @@ -17,7 +19,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.utils cimport data_from_unique_ptr -def findall(Column source_strings, pattern): +def findall(Column source_strings, object pattern, uint32_t flags): """ Returns data with all non-overlapping matches of `pattern` in each string of `source_strings`. @@ -26,11 +28,13 @@ def findall(Column source_strings, pattern): cdef column_view source_view = source_strings.view() cdef string pattern_string = str(pattern).encode() + cdef regex_flags c_flags = flags with nogil: c_result = move(cpp_findall( source_view, - pattern_string + pattern_string, + c_flags )) return data_from_unique_ptr( @@ -39,7 +43,7 @@ def findall(Column source_strings, pattern): ) -def findall_record(Column source_strings, pattern): +def findall_record(Column source_strings, object pattern, uint32_t flags): """ Returns data with all non-overlapping matches of `pattern` in each string of `source_strings` as a lists column. @@ -48,11 +52,13 @@ def findall_record(Column source_strings, pattern): cdef column_view source_view = source_strings.view() cdef string pattern_string = str(pattern).encode() + cdef regex_flags c_flags = flags with nogil: c_result = move(cpp_findall_record( source_view, - pattern_string + pattern_string, + c_flags )) return Column.from_unique_ptr(move(c_result)) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 22b7a0f9d2c..ee1ddb58abc 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -3410,6 +3410,8 @@ def findall( ---------- pat : str Pattern or regular expression. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) Returns ------- @@ -3419,7 +3421,8 @@ def findall( Notes ----- - `flags` parameter is currently not supported. + The `flags` parameter currently only supports re.DOTALL and + re.MULTILINE. Examples -------- @@ -3462,10 +3465,15 @@ def findall( 1 2 b b """ - if flags != 0: - raise NotImplementedError("`flags` parameter is not yet supported") + if isinstance(pat, re.Pattern): + flags = pat.flags & ~re.U + pat = pat.pattern + if not _is_supported_regex_flags(flags): + raise NotImplementedError( + "unsupported value for `flags` parameter" + ) - data, index = libstrings.findall(self._column, pat) + data, index = libstrings.findall(self._column, pat, flags) return self._return_or_inplace( cudf.core.frame.Frame(data, index), expand=expand ) diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index 75cf2e6c892..653c79fe603 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -1775,14 +1775,23 @@ def test_string_count(data, pat, flags): def test_string_findall(): - ps = pd.Series(["Lion", "Monkey", "Rabbit"]) - gs = cudf.Series(["Lion", "Monkey", "Rabbit"]) + test_data = ["Lion", "Monkey", "Rabbit", "Don\nkey"] + ps = pd.Series(test_data) + gs = cudf.Series(test_data) assert_eq(ps.str.findall("Monkey")[1][0], gs.str.findall("Monkey")[0][1]) assert_eq(ps.str.findall("on")[0][0], gs.str.findall("on")[0][0]) assert_eq(ps.str.findall("on")[1][0], gs.str.findall("on")[0][1]) - assert_eq(ps.str.findall("on$")[0][0], gs.str.findall("on$")[0][0]) assert_eq(ps.str.findall("b")[2][1], gs.str.findall("b")[1][2]) + assert_eq(ps.str.findall("on$")[0][0], gs.str.findall("on$")[0][0]) + assert_eq( + ps.str.findall("on$", re.MULTILINE)[3][0], + gs.str.findall("on$", re.MULTILINE)[0][3], + ) + assert_eq( + ps.str.findall("o.*k", re.DOTALL)[3][0], + gs.str.findall("o.*k", re.DOTALL)[0][3], + ) def test_string_replace_multi():