diff --git a/cpp/include/cudf/strings/replace_re.hpp b/cpp/include/cudf/strings/replace_re.hpp index 087d1a94603..a2c4eba1636 100644 --- a/cpp/include/cudf/strings/replace_re.hpp +++ b/cpp/include/cudf/strings/replace_re.hpp @@ -17,6 +17,7 @@ #include #include +#include #include namespace cudf { @@ -37,22 +38,25 @@ namespace strings { * * @param strings Strings instance for this operation. * @param pattern The regular expression pattern to search within each string. - * @param repl The string used to replace the matched sequence in each string. + * @param replacement The string used to replace the matched sequence in each string. * Default is an empty string. - * @param maxrepl The maximum number of times to replace the matched pattern within each string. + * @param max_replace_count The maximum number of times to replace the matched pattern + * within each string. Default replaces every substring that is matched. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New strings column. */ std::unique_ptr replace_re( strings_column_view const& strings, std::string const& pattern, - string_scalar const& repl = string_scalar(""), - size_type maxrepl = -1, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + string_scalar const& replacement = string_scalar(""), + std::optional max_replace_count = std::nullopt, + regex_flags const flags = regex_flags::DEFAULT, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** * @brief For each string, replaces any character sequence matching the given patterns - * with the corresponding string in the repls column. + * with the corresponding string in the `replacements` column. * * Any null string entries return corresponding null output column entries. * @@ -60,14 +64,16 @@ std::unique_ptr replace_re( * * @param strings Strings instance for this operation. * @param patterns The regular expression patterns to search within each string. - * @param repls The strings used for replacement. + * @param replacements The strings used for replacement. + * @param flags Regex flags for interpreting special characters in the patterns. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New strings column. */ std::unique_ptr replace_re( strings_column_view const& strings, std::vector const& patterns, - strings_column_view const& repls, + strings_column_view const& replacements, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -83,6 +89,7 @@ std::unique_ptr replace_re( * @param strings Strings instance for this operation. * @param pattern The regular expression patterns to search within each string. * @param replacement The replacement template for creating the output string. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New strings column. */ @@ -90,6 +97,7 @@ std::unique_ptr replace_with_backrefs( strings_column_view const& strings, std::string const& pattern, std::string const& replacement, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); } // namespace strings diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 99c55998fb9..ff86d7aa552 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -101,22 +101,24 @@ std::pair> parse_backrefs(std::string con std::unique_ptr replace_with_backrefs( strings_column_view const& strings, std::string const& pattern, - std::string const& repl, + std::string const& replacement, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { if (strings.is_empty()) return make_empty_column(type_id::STRING); CUDF_EXPECTS(!pattern.empty(), "Parameter pattern must not be empty"); - CUDF_EXPECTS(!repl.empty(), "Parameter repl must not be empty"); + CUDF_EXPECTS(!replacement.empty(), "Parameter replacement must not be empty"); auto d_strings = column_device_view::create(strings.parent(), stream); // compile regex into device object - auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings.size(), stream); + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), strings.size(), stream); auto const regex_insts = d_prog->insts_counts(); // parse the repl string for back-ref indicators - auto const parse_result = parse_backrefs(repl); + auto const parse_result = parse_backrefs(replacement); rmm::device_uvector backrefs = cudf::detail::make_device_uvector_async(parse_result.second, stream); string_scalar repl_scalar(parse_result.first, true, stream); @@ -170,11 +172,13 @@ std::unique_ptr replace_with_backrefs( std::unique_ptr replace_with_backrefs(strings_column_view const& strings, std::string const& pattern, - std::string const& repl, + std::string const& replacement, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_with_backrefs(strings, pattern, repl, rmm::cuda_stream_default, mr); + return detail::replace_with_backrefs( + strings, pattern, replacement, flags, rmm::cuda_stream_default, mr); } } // namespace strings diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index 25417909c89..2b5380b76dd 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -53,12 +54,11 @@ using found_range = thrust::pair; template struct replace_multi_regex_fn { column_device_view const d_strings; - reprog_device* progs; // array of regex progs - size_type number_of_patterns; - found_range* d_found_ranges; // working array matched (begin,end) values - column_device_view const d_repls; // replacement strings - int32_t* d_offsets{}; // these are null when - char* d_chars{}; // only computing size + device_span progs; // array of regex progs + found_range* d_found_ranges; // working array matched (begin,end) values + column_device_view const d_repls; // replacement strings + int32_t* d_offsets{}; + char* d_chars{}; __device__ void operator()(size_type idx) { @@ -66,6 +66,9 @@ struct replace_multi_regex_fn { if (!d_chars) d_offsets[idx] = 0; return; } + + auto const number_of_patterns = static_cast(progs.size()); + auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); // number of characters in input string auto nbytes = d_str.size_bytes(); // number of bytes in input string @@ -129,7 +132,8 @@ struct replace_multi_regex_fn { std::unique_ptr replace_re( strings_column_view const& strings, std::vector const& patterns, - strings_column_view const& repls, + strings_column_view const& replacements, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { @@ -138,31 +142,25 @@ std::unique_ptr replace_re( if (patterns.empty()) // no patterns; just return a copy return std::make_unique(strings.parent(), stream, mr); - CUDF_EXPECTS(!repls.has_nulls(), "Parameter repls must not have any nulls"); + CUDF_EXPECTS(!replacements.has_nulls(), "Parameter replacements must not have any nulls"); - auto d_strings = column_device_view::create(strings.parent(), stream); - auto d_repls = column_device_view::create(repls.parent(), stream); - auto d_flags = get_character_flags_table(); + auto d_strings = column_device_view::create(strings.parent(), stream); + auto d_repls = column_device_view::create(replacements.parent(), stream); + auto d_char_table = get_character_flags_table(); // compile regexes into device objects size_type regex_insts = 0; std::vector>> h_progs; - thrust::host_vector progs; + std::vector progs; for (auto itr = patterns.begin(); itr != patterns.end(); ++itr) { - auto prog = reprog_device::create(*itr, d_flags, strings_count, stream); + auto prog = reprog_device::create(*itr, flags, d_char_table, strings_count, stream); regex_insts = std::max(regex_insts, prog->insts_counts()); progs.push_back(*prog); h_progs.emplace_back(std::move(prog)); } // copy all the reprog_device instances to a device memory array - rmm::device_buffer progs_buffer{sizeof(reprog_device) * progs.size(), stream}; - CUDA_TRY(cudaMemcpyAsync(progs_buffer.data(), - progs.data(), - progs.size() * sizeof(reprog_device), - cudaMemcpyHostToDevice, - stream.value())); - reprog_device* d_progs = reinterpret_cast(progs_buffer.data()); + auto d_progs = cudf::detail::make_device_uvector_async(progs, stream); // create working buffer for ranges pairs rmm::device_uvector found_ranges(patterns.size() * strings_count, stream); @@ -172,34 +170,19 @@ std::unique_ptr replace_re( auto children = [&] { // Each invocation is predicated on the stack size which is dependent on the number of regex // instructions - if (regex_insts <= RX_SMALL_INSTS) - return make_strings_children( - replace_multi_regex_fn{ - *d_strings, d_progs, static_cast(progs.size()), d_found_ranges, *d_repls}, - strings_count, - stream, - mr); - else if (regex_insts <= RX_MEDIUM_INSTS) - return make_strings_children( - replace_multi_regex_fn{ - *d_strings, d_progs, static_cast(progs.size()), d_found_ranges, *d_repls}, - strings_count, - stream, - mr); - else if (regex_insts <= RX_LARGE_INSTS) - return make_strings_children( - replace_multi_regex_fn{ - *d_strings, d_progs, static_cast(progs.size()), d_found_ranges, *d_repls}, - strings_count, - stream, - mr); - else - return make_strings_children( - replace_multi_regex_fn{ - *d_strings, d_progs, static_cast(progs.size()), d_found_ranges, *d_repls}, - strings_count, - stream, - mr); + if (regex_insts <= RX_SMALL_INSTS) { + replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; + return make_strings_children(fn, strings_count, stream, mr); + } else if (regex_insts <= RX_MEDIUM_INSTS) { + replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; + return make_strings_children(fn, strings_count, stream, mr); + } else if (regex_insts <= RX_LARGE_INSTS) { + replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; + return make_strings_children(fn, strings_count, stream, mr); + } else { + replace_multi_regex_fn fn{*d_strings, d_progs, d_found_ranges, *d_repls}; + return make_strings_children(fn, strings_count, stream, mr); + } }(); return make_strings_column(strings_count, @@ -215,11 +198,12 @@ std::unique_ptr replace_re( std::unique_ptr replace_re(strings_column_view const& strings, std::vector const& patterns, - strings_column_view const& repls, + strings_column_view const& replacements, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_re(strings, patterns, repls, rmm::cuda_stream_default, mr); + return detail::replace_re(strings, patterns, replacements, flags, rmm::cuda_stream_default, mr); } } // namespace strings diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index b940944c186..9fd1768453a 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -52,7 +52,7 @@ struct replace_regex_fn { column_device_view const d_strings; reprog_device prog; string_view const d_repl; - size_type maxrepl; + size_type const maxrepl; int32_t* d_offsets{}; char* d_chars{}; @@ -102,56 +102,48 @@ struct replace_regex_fn { std::unique_ptr replace_re( strings_column_view const& strings, std::string const& pattern, - string_scalar const& repl = string_scalar(""), - size_type maxrepl = -1, + string_scalar const& replacement, + std::optional max_replace_count, + regex_flags const flags, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { auto strings_count = strings.size(); if (strings_count == 0) return make_empty_column(type_id::STRING); - CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid"); - string_view d_repl(repl.data(), repl.size()); + CUDF_EXPECTS(replacement.is_valid(stream), "Parameter replacement must be valid"); + string_view d_repl(replacement.data(), replacement.size()); auto strings_column = column_device_view::create(strings.parent(), stream); auto d_strings = *strings_column; // compile regex into device object - auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); - auto d_prog = *prog; - auto regex_insts = d_prog.insts_counts(); + auto prog = + reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); + auto d_prog = *prog; + auto const regex_insts = d_prog.insts_counts(); // copy null mask - auto null_mask = cudf::detail::copy_bitmask(strings.parent(), stream, mr); - auto null_count = strings.null_count(); + auto null_mask = cudf::detail::copy_bitmask(strings.parent(), stream, mr); + auto const null_count = strings.null_count(); + auto const maxrepl = max_replace_count.value_or(-1); // create child columns auto children = [&] { // Each invocation is predicated on the stack size which is dependent on the number of regex // instructions - if (regex_insts <= RX_SMALL_INSTS) - return make_strings_children( - replace_regex_fn{d_strings, d_prog, d_repl, maxrepl}, - strings_count, - stream, - mr); - else if (regex_insts <= RX_MEDIUM_INSTS) - return make_strings_children( - replace_regex_fn{d_strings, d_prog, d_repl, maxrepl}, - strings_count, - stream, - mr); - else if (regex_insts <= RX_LARGE_INSTS) - return make_strings_children( - replace_regex_fn{d_strings, d_prog, d_repl, maxrepl}, - strings_count, - stream, - mr); - else - return make_strings_children( - replace_regex_fn{d_strings, d_prog, d_repl, maxrepl}, - strings_count, - stream, - mr); + if (regex_insts <= RX_SMALL_INSTS) { + replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; + return make_strings_children(fn, strings_count, stream, mr); + } else if (regex_insts <= RX_MEDIUM_INSTS) { + replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; + return make_strings_children(fn, strings_count, stream, mr); + } else if (regex_insts <= RX_LARGE_INSTS) { + replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; + return make_strings_children(fn, strings_count, stream, mr); + } else { + replace_regex_fn fn{d_strings, d_prog, d_repl, maxrepl}; + return make_strings_children(fn, strings_count, stream, mr); + } }(); return make_strings_column(strings_count, @@ -167,12 +159,14 @@ std::unique_ptr replace_re( std::unique_ptr replace_re(strings_column_view const& strings, std::string const& pattern, - string_scalar const& repl, - size_type maxrepl, + string_scalar const& replacement, + std::optional max_replace_count, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::replace_re(strings, pattern, repl, maxrepl, rmm::cuda_stream_default, mr); + return detail::replace_re( + strings, pattern, replacement, max_replace_count, flags, rmm::cuda_stream_default, mr); } } // namespace strings diff --git a/cpp/tests/strings/replace_regex_tests.cpp b/cpp/tests/strings/replace_regex_tests.cpp index 16308265a9b..eac06fa4588 100644 --- a/cpp/tests/strings/replace_regex_tests.cpp +++ b/cpp/tests/strings/replace_regex_tests.cpp @@ -133,6 +133,58 @@ TEST_F(StringsReplaceRegexTest, WithEmptyPattern) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, strings); } +TEST_F(StringsReplaceRegexTest, MultiReplacement) +{ + cudf::test::strings_column_wrapper input({"aba bcd aba", "abababa abababa"}); + auto results = + cudf::strings::replace_re(cudf::strings_column_view(input), "aba", cudf::string_scalar("_"), 2); + cudf::test::strings_column_wrapper expected({"_ bcd _", "_b_ abababa"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected); + results = + cudf::strings::replace_re(cudf::strings_column_view(input), "aba", cudf::string_scalar(""), 0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, input); +} + +TEST_F(StringsReplaceRegexTest, Multiline) +{ + auto const multiline = cudf::strings::regex_flags::MULTILINE; + + cudf::test::strings_column_wrapper input({"bcd\naba\nefg", "aba\naba abab\naba", "aba"}); + auto sv = cudf::strings_column_view(input); + + // single-replace + auto results = + cudf::strings::replace_re(sv, "^aba$", cudf::string_scalar("_"), std::nullopt, multiline); + cudf::test::strings_column_wrapper expected_ml({"bcd\n_\nefg", "_\naba abab\n_", "_"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_ml); + + results = cudf::strings::replace_re(sv, "^aba$", cudf::string_scalar("_")); + cudf::test::strings_column_wrapper expected({"bcd\naba\nefg", "aba\naba abab\naba", "_"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected); + + // multi-replace + std::vector patterns({"aba$", "^aba"}); + cudf::test::strings_column_wrapper repls({">", "<"}); + results = cudf::strings::replace_re(sv, patterns, cudf::strings_column_view(repls), multiline); + cudf::test::strings_column_wrapper multi_expected_ml({"bcd\n>\nefg", ">\n< abab\n>", ">"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, multi_expected_ml); + + results = cudf::strings::replace_re(sv, patterns, cudf::strings_column_view(repls)); + cudf::test::strings_column_wrapper multi_expected({"bcd\naba\nefg", "<\naba abab\n>", ">"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, multi_expected); + + // backref-replace + results = cudf::strings::replace_with_backrefs(sv, "(^aba)", "[\\1]", multiline); + cudf::test::strings_column_wrapper br_expected_ml( + {"bcd\n[aba]\nefg", "[aba]\n[aba] abab\n[aba]", "[aba]"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, br_expected_ml); + + results = cudf::strings::replace_with_backrefs(sv, "(^aba)", "[\\1]"); + cudf::test::strings_column_wrapper br_expected( + {"bcd\naba\nefg", "[aba]\naba abab\naba", "[aba]"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, br_expected); +} + TEST_F(StringsReplaceRegexTest, ReplaceBackrefsRegexTest) { std::vector h_strings{"the quick brown fox jumps over the lazy dog",