diff --git a/cpp/include/cudf/strings/replace_re.hpp b/cpp/include/cudf/strings/replace_re.hpp index 28ab19e53d9..087d1a94603 100644 --- a/cpp/include/cudf/strings/replace_re.hpp +++ b/cpp/include/cudf/strings/replace_re.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,22 +72,24 @@ std::unique_ptr replace_re( /** * @brief For each string, replaces any character sequence matching the given pattern - * using the repl template for back-references. + * using the replacement template for back-references. * * Any null string entries return corresponding null output column entries. * * See the @ref md_regex "Regex Features" page for details on patterns supported by this API. * + * @throw cudf::logic_error if capture index values in `replacement` are not in range 1-99 + * * @param strings Strings instance for this operation. * @param pattern The regular expression patterns to search within each string. - * @param repl The replacement template for creating the output string. + * @param replacement The replacement template for creating the output string. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New strings column. */ std::unique_ptr replace_with_backrefs( strings_column_view const& strings, std::string const& pattern, - std::string const& repl, + std::string const& replacement, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); } // namespace strings diff --git a/cpp/src/strings/replace/backref_re.cu b/cpp/src/strings/replace/backref_re.cu index 462efedffe5..5f7b195e8f9 100644 --- a/cpp/src/strings/replace/backref_re.cu +++ b/cpp/src/strings/replace/backref_re.cu @@ -37,39 +37,57 @@ namespace strings { namespace detail { namespace { +/** + * @brief Return the capturing group index pattern to use with the given replacement string. + * + * Only two patterns are supported at this time `\d` and `${d}` where `d` is an integer in + * the range 1-99. The `\d` pattern is returned by default unless no `\d` pattern is found in + * the `repl` string, + * + * Reference: https://www.regular-expressions.info/refreplacebackref.html + */ +std::string get_backref_pattern(std::string const& repl) +{ + std::string const backslash_pattern = "\\\\(\\d+)"; + std::string const bracket_pattern = "\\$\\{(\\d+)\\}"; + std::smatch m; + return std::regex_search(repl, m, std::regex(backslash_pattern)) ? backslash_pattern + : bracket_pattern; +} /** * @brief Parse the back-ref index and position values from a given replace format. * - * The backref numbers are expected to be 1-based. + * The back-ref numbers are expected to be 1-based. + * + * Returns a modified string without back-ref indicators and a vector of back-ref + * byte position pairs. These are used by the device code to build the output + * string by placing the captured group elements into the replace format. * - * Returns a modified string without back-ref indicators and a vector of backref - * byte position pairs. - * ``` - * Example: - * for input string: 'hello \2 and \1' - * the returned pairs: (2,6),(1,11) - * returned string is: 'hello and ' - * ``` + * For example, for input string 'hello \2 and \1' the returned `backref_type` vector + * contains `[(2,6),(1,11)]` and the returned string is 'hello and '. */ std::pair> parse_backrefs(std::string const& repl) { std::vector backrefs; std::string str = repl; // make a modifiable copy std::smatch m; - std::regex ex("(\\\\\\d+)"); // this searches for backslash-number(s); example "\1" - std::string rtn; // result without refs + std::regex ex(get_backref_pattern(repl)); + std::string rtn; size_type byte_offset = 0; - while (std::regex_search(str, m, ex)) { - if (m.size() == 0) break; - std::string const backref = m[0]; - size_type const position = static_cast(m.position(0)); - size_type const length = static_cast(backref.length()); + while (std::regex_search(str, m, ex) && !m.empty()) { + // parse the back-ref index number + size_type const index = static_cast(std::atoi(std::string{m[1]}.c_str())); + CUDF_EXPECTS(index > 0 && index < 100, "Group index numbers must be in the range 1-99"); + + // store the new byte offset and index value + size_type const position = static_cast(m.position(0)); byte_offset += position; - size_type const index = std::atoi(backref.c_str() + 1); // back-ref index number - CUDF_EXPECTS(index > 0, "Back-reference numbers must be greater than 0"); - rtn += str.substr(0, position); - str = str.substr(position + length); backrefs.push_back({index, byte_offset}); + + // update the output string + rtn += str.substr(0, position); + // remove the back-ref pattern to continue parsing + str = str.substr(position + static_cast(m.length(0))); } if (!str.empty()) // add the remainder rtn += str; // of the string @@ -96,7 +114,7 @@ std::unique_ptr replace_with_backrefs( auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings.size(), stream); auto const regex_insts = d_prog->insts_counts(); - // parse the repl string for backref indicators + // parse the repl string for back-ref indicators auto const parse_result = parse_backrefs(repl); rmm::device_uvector backrefs(parse_result.second.size(), stream); CUDA_TRY(cudaMemcpyAsync(backrefs.data(), diff --git a/cpp/tests/strings/replace_regex_tests.cpp b/cpp/tests/strings/replace_regex_tests.cpp index a2486d60051..1f01f0f1429 100644 --- a/cpp/tests/strings/replace_regex_tests.cpp +++ b/cpp/tests/strings/replace_regex_tests.cpp @@ -167,6 +167,20 @@ TEST_F(StringsReplaceTests, ReplaceBackrefsRegexTest) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } +TEST_F(StringsReplaceTests, ReplaceBackrefsRegexAltIndexPatternTest) +{ + cudf::test::strings_column_wrapper strings({"12-3 34-5 67-89", "0-99: 777-888:: 5673-0"}); + auto strings_view = cudf::strings_column_view(strings); + + std::string pattern = "(\\d+)-(\\d+)"; + std::string repl_template = "${2} X ${1}0"; + auto results = cudf::strings::replace_with_backrefs(strings_view, pattern, repl_template); + + cudf::test::strings_column_wrapper expected( + {"3 X 120 5 X 340 89 X 670", "99 X 00: 888 X 7770:: 0 X 56730"}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); +} + TEST_F(StringsReplaceTests, ReplaceBackrefsRegexReversedTest) { cudf::test::strings_column_wrapper strings( @@ -203,6 +217,17 @@ TEST_F(StringsReplaceTests, BackrefWithGreedyQuantifier) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } +TEST_F(StringsReplaceTests, ReplaceBackrefsRegexErrorTest) +{ + cudf::test::strings_column_wrapper strings({"this string left intentionally blank"}); + auto view = cudf::strings_column_view(strings); + + EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "(\\w)", "\\0"), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "(\\w)", "\\123"), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "", "\\1"), cudf::logic_error); + EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "(\\w)", ""), cudf::logic_error); +} + TEST_F(StringsReplaceTests, MediumReplaceRegex) { // This results in 95 regex instructions and falls in the 'medium' range.