Skip to content

Commit

Permalink
Adjust the valid range of group index for replace_with_backrefs (#10530)
Browse files Browse the repository at this point in the history
This PR adjusts the valid range of group indices for the cuDF API `cudf::strings::replace_with_backrefs`.

1.  enable 0 as group index
Currently, valid group indices start at 1, which excludes the special value 0. A back-reference index of zero conventionally refers to the entire match, and the cuDF regex engine already treats it that way.
Therefore, all we need to do is lift the restriction so that zero can be passed as the group index of a back-reference.
Example of zero-value index:
input: `aa-11 b2b-345`
pattern: `([a-z]+)-([0-9]+)`
replacement: `${0}:${1}:${2};`
output: ```aa-11:aa:11; b2b-345:b:345;```

2. group index should not exceed group count
Currently, group indices can exceed the group count, and any such index silently resolves to an empty string. IMHO, it is better to throw an exception in this case than to silently ignore the out-of-range indices.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - David Wendt (https://github.com/davidwendt)

URL: #10530
  • Loading branch information
sperlingxx authored Mar 31, 2022
1 parent 1355191 commit bc8f578
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 7 deletions.
5 changes: 3 additions & 2 deletions cpp/include/cudf/strings/replace_re.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -86,7 +86,8 @@ std::unique_ptr<column> replace_re(
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @throw cudf::logic_error if capture index values in `replacement` are not in range 1-99
* @throw cudf::logic_error if capture index values in `replacement` are not in range 0-99, and also
* if the index exceeds the group count specified in the pattern
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression patterns to search within each string.
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/strings/replace/backref_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ std::string get_backref_pattern(std::string const& repl)
* For example, for input string 'hello \2 and \1' the returned `backref_type` vector
* contains `[(2,6),(1,11)]` and the returned string is 'hello and '.
*/
std::pair<std::string, std::vector<backref_type>> parse_backrefs(std::string const& repl)
std::pair<std::string, std::vector<backref_type>> parse_backrefs(std::string const& repl,
int const group_count)
{
std::vector<backref_type> backrefs;
std::string str = repl; // make a modifiable copy
Expand All @@ -79,7 +80,8 @@ std::pair<std::string, std::vector<backref_type>> parse_backrefs(std::string con
while (std::regex_search(str, m, ex) && !m.empty()) {
// parse the back-ref index number
size_type const index = static_cast<size_type>(std::atoi(std::string{m[1]}.c_str()));
CUDF_EXPECTS(index > 0 && index < 100, "Group index numbers must be in the range 1-99");
CUDF_EXPECTS(index >= 0 && index <= group_count,
"Group index numbers must be in the range 0 to group count");

// store the new byte offset and index value
size_type const position = static_cast<size_type>(m.position(0));
Expand Down Expand Up @@ -146,7 +148,8 @@ std::unique_ptr<column> replace_with_backrefs(
reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream);

// parse the repl string for back-ref indicators
auto const parse_result = parse_backrefs(replacement);
auto group_count = std::min(99, d_prog->group_counts()); // group count should NOT exceed 99
auto const parse_result = parse_backrefs(replacement, group_count);
rmm::device_uvector<backref_type> backrefs =
cudf::detail::make_device_uvector_async(parse_result.second, stream);
string_scalar repl_scalar(parse_result.first, true, stream);
Expand Down
23 changes: 21 additions & 2 deletions cpp/tests/strings/replace_regex_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,32 @@ TEST_F(StringsReplaceRegexTest, BackrefWithGreedyQuantifier)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

// Checks that `${0}` in the replacement template expands to the whole match,
// alongside the regular capture groups `${1}` and `${2}`.
TEST_F(StringsReplaceRegexTest, ReplaceBackrefsRegexZeroIndexTest)
{
  // Expected output, one entry per input row below; the last row has no match
  // and is therefore expected to come back unchanged.
  cudf::test::strings_column_wrapper expected({
    "TEST123: TEST, 123; ",
    "TEST1: TEST, 1; TEST2: TEST, 2; ",
    "TEST2: TEST, 2; -TEST1122: TEST, 1122; ",
    "TEST1: TEST, 1; -TEST-T",
    "TES3",
  });

  cudf::test::strings_column_wrapper input(
    {"TEST123", "TEST1TEST2", "TEST2-TEST1122", "TEST1-TEST-T", "TES3"});
  auto const input_view = cudf::strings_column_view(input);

  // Two capture groups: the literal "TEST" and the digits that follow it.
  std::string const regex_pattern{"(TEST)(\\d+)"};
  std::string const replacement{"${0}: ${1}, ${2}; "};

  auto const actual = cudf::strings::replace_with_backrefs(input_view, regex_pattern, replacement);
  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*actual, expected);
}

// Checks that invalid pattern/replacement combinations are rejected with
// cudf::logic_error.
//
// Group index 0 is intentionally NOT listed here: it is a valid index that
// refers to the entire match, so `"\\0"` must no longer be expected to throw.
TEST_F(StringsReplaceRegexTest, ReplaceBackrefsRegexErrorTest)
{
  cudf::test::strings_column_wrapper strings({"this string left intentionally blank"});
  auto view = cudf::strings_column_view(strings);

  // group index(123) exceeds the group count(1)
  EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "(\\w)", "\\123"), cudf::logic_error);
  // group index(3) exceeds the group count(2)
  EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "(\\w).(\\w)", "\\3"), cudf::logic_error);
  // an empty pattern is expected to be rejected
  EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "", "\\1"), cudf::logic_error);
  // an empty replacement is expected to be rejected
  EXPECT_THROW(cudf::strings::replace_with_backrefs(view, "(\\w)", ""), cudf::logic_error);
}
Expand Down
16 changes: 16 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4987,6 +4987,22 @@ void testStringReplaceWithBackrefs() {
assertColumnsAreEqual(expected, actual);
}

// test zero as group index
try (ColumnVector v = ColumnVector.fromStrings("aa-11 b2b-345", "aa-11a 1c-2b2 b2-c3", "11-aa", null);
ColumnVector expected = ColumnVector.fromStrings("aa-11:aa:11; b2b-345:b:345;",
"aa-11:aa:11;a 1c-2:c:2;b2 b2-c3", "11-aa", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"([a-z]+)-([0-9]+)", "${0}:${1}:${2};")) {
assertColumnsAreEqual(expected, actual);
}

// group index exceeds group count
assertThrows(CudfException.class, () -> {
try (ColumnVector v = ColumnVector.fromStrings("ABC123defgh");
ColumnVector r = v.stringReplaceWithBackrefs("([A-Z]+)([0-9]+)([a-z]+)", "\\4")) {
}
});

}

@Test
Expand Down

0 comments on commit bc8f578

Please sign in to comment.