Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support bracket syntax for cudf::strings::replace_with_backrefs group index values #8841

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions cpp/include/cudf/strings/replace_re.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,22 +72,24 @@ std::unique_ptr<column> replace_re(

/**
* @brief For each string, replaces any character sequence matching the given pattern
* using the repl template for back-references.
* using the replacement template for back-references.
*
* Any null string entries return corresponding null output column entries.
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @throw cudf::logic_error if capture index values in `replacement` are not in range 1-99
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression pattern to search for within each string.
* @param repl The replacement template for creating the output string.
* @param replacement The replacement template for creating the output string.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
*/
std::unique_ptr<column> replace_with_backrefs(
strings_column_view const& strings,
std::string const& pattern,
std::string const& repl,
std::string const& replacement,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace strings
Expand Down
60 changes: 39 additions & 21 deletions cpp/src/strings/replace/backref_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,57 @@ namespace strings {
namespace detail {
namespace {

/**
 * @brief Select the regex used to locate capture-group references in a replacement template.
 *
 * Two reference syntaxes are recognized: the backslash form `\N` and the bracket form `${N}`,
 * where `N` is one or more decimal digits. Exactly one pattern is returned:
 * - the backslash pattern if `repl` contains at least one `\N` reference;
 * - the bracket pattern otherwise, including when `repl` contains neither form.
 *
 * In both returned patterns, capture group 1 holds the digits of the group index.
 * No range check on the index value is performed here.
 *
 * Reference: https://www.regular-expressions.info/refreplacebackref.html
 *
 * @param repl Replacement template to inspect.
 * @return Regex pattern string for the reference syntax chosen for `repl`.
 */
std::string get_backref_pattern(std::string const& repl)
{
  std::string const backslash_form = "\\\\(\\d+)";
  std::string const bracket_form   = "\\$\\{(\\d+)\\}";

  // Only a yes/no answer is needed, so no match-results object is requested.
  if (std::regex_search(repl, std::regex(backslash_form))) { return backslash_form; }
  return bracket_form;
}
/**
* @brief Parse the back-ref index and position values from a given replace format.
*
* The backref numbers are expected to be 1-based.
* The back-ref numbers are expected to be 1-based.
*
* Returns a modified string without back-ref indicators and a vector of back-ref
* byte position pairs. These are used by the device code to build the output
* string by placing the captured group elements into the replace format.
*
* Returns a modified string without back-ref indicators and a vector of backref
* byte position pairs.
* ```
* Example:
* for input string: 'hello \2 and \1'
* the returned pairs: (2,6),(1,11)
* returned string is: 'hello and '
* ```
* For example, for input string 'hello \2 and \1' the returned `backref_type` vector
* contains `[(2,6),(1,11)]` and the returned string is 'hello and '.
*/
std::pair<std::string, std::vector<backref_type>> parse_backrefs(std::string const& repl)
{
std::vector<backref_type> backrefs;
std::string str = repl; // make a modifiable copy
std::smatch m;
std::regex ex("(\\\\\\d+)"); // this searches for backslash-number(s); example "\1"
std::string rtn; // result without refs
std::regex ex(get_backref_pattern(repl));
std::string rtn;
size_type byte_offset = 0;
while (std::regex_search(str, m, ex)) {
if (m.size() == 0) break;
std::string const backref = m[0];
size_type const position = static_cast<size_type>(m.position(0));
size_type const length = static_cast<size_type>(backref.length());
while (std::regex_search(str, m, ex) && !m.empty()) {
// parse the back-ref index number
size_type const index = static_cast<size_type>(std::atoi(std::string{m[1]}.c_str()));
CUDF_EXPECTS(index > 0 && index < 100, "Group index numbers must be in the range 1-99");

// store the new byte offset and index value
size_type const position = static_cast<size_type>(m.position(0));
byte_offset += position;
size_type const index = std::atoi(backref.c_str() + 1); // back-ref index number
CUDF_EXPECTS(index > 0, "Back-reference numbers must be greater than 0");
rtn += str.substr(0, position);
str = str.substr(position + length);
backrefs.push_back({index, byte_offset});

// update the output string
rtn += str.substr(0, position);
// remove the back-ref pattern to continue parsing
str = str.substr(position + static_cast<size_type>(m.length(0)));
}
if (!str.empty()) // add the remainder
rtn += str; // of the string
Expand All @@ -96,7 +114,7 @@ std::unique_ptr<column> replace_with_backrefs(
auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings.size(), stream);
auto const regex_insts = d_prog->insts_counts();

// parse the repl string for backref indicators
// parse the repl string for back-ref indicators
auto const parse_result = parse_backrefs(repl);
rmm::device_uvector<backref_type> backrefs(parse_result.second.size(), stream);
CUDA_TRY(cudaMemcpyAsync(backrefs.data(),
Expand Down
25 changes: 25 additions & 0 deletions cpp/tests/strings/replace_regex_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ TEST_F(StringsReplaceTests, ReplaceBackrefsRegexTest)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

// Exercises the `${N}` bracket form of group index references in the replacement template.
TEST_F(StringsReplaceTests, ReplaceBackrefsRegexAltIndexPatternTest)
{
  cudf::test::strings_column_wrapper input({"12-3 34-5 67-89", "0-99: 777-888:: 5673-0"});
  auto input_view = cudf::strings_column_view(input);

  // Two numeric capture groups separated by a dash.
  std::string group_pattern = "(\\d+)-(\\d+)";
  // Swaps the two groups; the '0' after `${1}` is literal text, not part of the index.
  std::string bracket_template = "${2} X ${1}0";
  auto replaced =
    cudf::strings::replace_with_backrefs(input_view, group_pattern, bracket_template);

  cudf::test::strings_column_wrapper expected(
    {"3 X 120 5 X 340 89 X 670", "99 X 00: 888 X 7770:: 0 X 56730"});
  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*replaced, expected);
}

TEST_F(StringsReplaceTests, ReplaceBackrefsRegexReversedTest)
{
cudf::test::strings_column_wrapper strings(
Expand Down Expand Up @@ -203,6 +217,17 @@ TEST_F(StringsReplaceTests, BackrefWithGreedyQuantifier)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

// Each call below is expected to be rejected with cudf::logic_error.
TEST_F(StringsReplaceTests, ReplaceBackrefsRegexErrorTest)
{
  cudf::test::strings_column_wrapper input({"this string left intentionally blank"});
  auto input_view = cudf::strings_column_view(input);

  // group index 0
  EXPECT_THROW(cudf::strings::replace_with_backrefs(input_view, "(\\w)", "\\0"),
               cudf::logic_error);
  // group index 123
  EXPECT_THROW(cudf::strings::replace_with_backrefs(input_view, "(\\w)", "\\123"),
               cudf::logic_error);
  // empty pattern
  EXPECT_THROW(cudf::strings::replace_with_backrefs(input_view, "", "\\1"),
               cudf::logic_error);
  // empty replacement template
  EXPECT_THROW(cudf::strings::replace_with_backrefs(input_view, "(\\w)", ""),
               cudf::logic_error);
}

TEST_F(StringsReplaceTests, MediumReplaceRegex)
{
// This results in 95 regex instructions and falls in the 'medium' range.
Expand Down