Skip to content

Commit

Permalink
Add regex_flags parameter to strings replace_re functions (#9878)
Browse files Browse the repository at this point in the history
Closes #9845 

Adds a `cudf::strings::regex_flags` parameter to the `cudf::strings::replace_re` functions so the matching logic will be the same as for `cudf::strings::contains_re` which already has this parameter.

This is a breaking change since it adds this new parameter and changes the default behavior. The previous default behavior is equivalent to specifying the `regex_flags::MULTILINE` flag now to be consistent with the default behavior of `contains_re`.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Mike Wilson (https://github.com/hyperbolic2346)

URL: #9878
  • Loading branch information
davidwendt authored Dec 15, 2021
1 parent 38631a6 commit db9aef8
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 101 deletions.
24 changes: 16 additions & 8 deletions cpp/include/cudf/strings/replace_re.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <cudf/column/column.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/regex/flags.hpp>
#include <cudf/strings/strings_column_view.hpp>

namespace cudf {
Expand All @@ -37,37 +38,42 @@ namespace strings {
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression pattern to search within each string.
* @param repl The string used to replace the matched sequence in each string.
* @param replacement The string used to replace the matched sequence in each string.
* Default is an empty string.
* @param maxrepl The maximum number of times to replace the matched pattern within each string.
* @param max_replace_count The maximum number of times to replace the matched pattern
* within each string. Default replaces every substring that is matched.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
*/
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
std::string const& pattern,
string_scalar const& repl = string_scalar(""),
size_type maxrepl = -1,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
string_scalar const& replacement = string_scalar(""),
std::optional<size_type> max_replace_count = std::nullopt,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief For each string, replaces any character sequence matching the given patterns
* with the corresponding string in the repls column.
* with the corresponding string in the `replacements` column.
*
* Any null string entries return corresponding null output column entries.
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param patterns The regular expression patterns to search within each string.
* @param repls The strings used for replacement.
* @param replacements The strings used for replacement.
* @param flags Regex flags for interpreting special characters in the patterns.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
*/
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
std::vector<std::string> const& patterns,
strings_column_view const& repls,
strings_column_view const& replacements,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -83,13 +89,15 @@ std::unique_ptr<column> replace_re(
* @param strings Strings instance for this operation.
* @param pattern The regular expression patterns to search within each string.
* @param replacement The replacement template for creating the output string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
*/
std::unique_ptr<column> replace_with_backrefs(
strings_column_view const& strings,
std::string const& pattern,
std::string const& replacement,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace strings
Expand Down
16 changes: 10 additions & 6 deletions cpp/src/strings/replace/backref_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,24 @@ std::pair<std::string, std::vector<backref_type>> parse_backrefs(std::string con
std::unique_ptr<column> replace_with_backrefs(
strings_column_view const& strings,
std::string const& pattern,
std::string const& repl,
std::string const& replacement,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
if (strings.is_empty()) return make_empty_column(type_id::STRING);

CUDF_EXPECTS(!pattern.empty(), "Parameter pattern must not be empty");
CUDF_EXPECTS(!repl.empty(), "Parameter repl must not be empty");
CUDF_EXPECTS(!replacement.empty(), "Parameter replacement must not be empty");

auto d_strings = column_device_view::create(strings.parent(), stream);
// compile regex into device object
auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings.size(), stream);
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings.size(), stream);
auto const regex_insts = d_prog->insts_counts();

// parse the repl string for back-ref indicators
auto const parse_result = parse_backrefs(repl);
auto const parse_result = parse_backrefs(replacement);
rmm::device_uvector<backref_type> backrefs =
cudf::detail::make_device_uvector_async(parse_result.second, stream);
string_scalar repl_scalar(parse_result.first, true, stream);
Expand Down Expand Up @@ -170,11 +172,13 @@ std::unique_ptr<column> replace_with_backrefs(

std::unique_ptr<column> replace_with_backrefs(strings_column_view const& strings,
std::string const& pattern,
std::string const& repl,
std::string const& replacement,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_with_backrefs(strings, pattern, repl, rmm::cuda_stream_default, mr);
return detail::replace_with_backrefs(
strings, pattern, replacement, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
84 changes: 34 additions & 50 deletions cpp/src/strings/replace/multi_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/replace_re.hpp>
#include <cudf/strings/string_view.cuh>
Expand Down Expand Up @@ -53,19 +54,21 @@ using found_range = thrust::pair<size_type, size_type>;
template <int stack_size>
struct replace_multi_regex_fn {
column_device_view const d_strings;
reprog_device* progs; // array of regex progs
size_type number_of_patterns;
found_range* d_found_ranges; // working array matched (begin,end) values
column_device_view const d_repls; // replacement strings
int32_t* d_offsets{}; // these are null when
char* d_chars{}; // only computing size
device_span<reprog_device const> progs; // array of regex progs
found_range* d_found_ranges; // working array matched (begin,end) values
column_device_view const d_repls; // replacement strings
int32_t* d_offsets{};
char* d_chars{};

__device__ void operator()(size_type idx)
{
if (d_strings.is_null(idx)) {
if (!d_chars) d_offsets[idx] = 0;
return;
}

auto const number_of_patterns = static_cast<size_type>(progs.size());

auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length(); // number of characters in input string
auto nbytes = d_str.size_bytes(); // number of bytes in input string
Expand Down Expand Up @@ -129,7 +132,8 @@ struct replace_multi_regex_fn {
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
std::vector<std::string> const& patterns,
strings_column_view const& repls,
strings_column_view const& replacements,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
Expand All @@ -138,31 +142,25 @@ std::unique_ptr<column> replace_re(
if (patterns.empty()) // no patterns; just return a copy
return std::make_unique<column>(strings.parent(), stream, mr);

CUDF_EXPECTS(!repls.has_nulls(), "Parameter repls must not have any nulls");
CUDF_EXPECTS(!replacements.has_nulls(), "Parameter replacements must not have any nulls");

auto d_strings = column_device_view::create(strings.parent(), stream);
auto d_repls = column_device_view::create(repls.parent(), stream);
auto d_flags = get_character_flags_table();
auto d_strings = column_device_view::create(strings.parent(), stream);
auto d_repls = column_device_view::create(replacements.parent(), stream);
auto d_char_table = get_character_flags_table();

// compile regexes into device objects
size_type regex_insts = 0;
std::vector<std::unique_ptr<reprog_device, std::function<void(reprog_device*)>>> h_progs;
thrust::host_vector<reprog_device> progs;
std::vector<reprog_device> progs;
for (auto itr = patterns.begin(); itr != patterns.end(); ++itr) {
auto prog = reprog_device::create(*itr, d_flags, strings_count, stream);
auto prog = reprog_device::create(*itr, flags, d_char_table, strings_count, stream);
regex_insts = std::max(regex_insts, prog->insts_counts());
progs.push_back(*prog);
h_progs.emplace_back(std::move(prog));
}

// copy all the reprog_device instances to a device memory array
rmm::device_buffer progs_buffer{sizeof(reprog_device) * progs.size(), stream};
CUDA_TRY(cudaMemcpyAsync(progs_buffer.data(),
progs.data(),
progs.size() * sizeof(reprog_device),
cudaMemcpyHostToDevice,
stream.value()));
reprog_device* d_progs = reinterpret_cast<reprog_device*>(progs_buffer.data());
auto d_progs = cudf::detail::make_device_uvector_async(progs, stream);

// create working buffer for ranges pairs
rmm::device_uvector<found_range> found_ranges(patterns.size() * strings_count, stream);
Expand All @@ -172,34 +170,19 @@ std::unique_ptr<column> replace_re(
auto children = [&] {
// Each invocation is predicated on the stack size which is dependent on the number of regex
// instructions
if (regex_insts <= RX_SMALL_INSTS)
return make_strings_children(
replace_multi_regex_fn<RX_STACK_SMALL>{
*d_strings, d_progs, static_cast<size_type>(progs.size()), d_found_ranges, *d_repls},
strings_count,
stream,
mr);
else if (regex_insts <= RX_MEDIUM_INSTS)
return make_strings_children(
replace_multi_regex_fn<RX_STACK_MEDIUM>{
*d_strings, d_progs, static_cast<size_type>(progs.size()), d_found_ranges, *d_repls},
strings_count,
stream,
mr);
else if (regex_insts <= RX_LARGE_INSTS)
return make_strings_children(
replace_multi_regex_fn<RX_STACK_LARGE>{
*d_strings, d_progs, static_cast<size_type>(progs.size()), d_found_ranges, *d_repls},
strings_count,
stream,
mr);
else
return make_strings_children(
replace_multi_regex_fn<RX_STACK_ANY>{
*d_strings, d_progs, static_cast<size_type>(progs.size()), d_found_ranges, *d_repls},
strings_count,
stream,
mr);
if (regex_insts <= RX_SMALL_INSTS) {
replace_multi_regex_fn<RX_STACK_SMALL> fn{*d_strings, d_progs, d_found_ranges, *d_repls};
return make_strings_children(fn, strings_count, stream, mr);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
replace_multi_regex_fn<RX_STACK_MEDIUM> fn{*d_strings, d_progs, d_found_ranges, *d_repls};
return make_strings_children(fn, strings_count, stream, mr);
} else if (regex_insts <= RX_LARGE_INSTS) {
replace_multi_regex_fn<RX_STACK_LARGE> fn{*d_strings, d_progs, d_found_ranges, *d_repls};
return make_strings_children(fn, strings_count, stream, mr);
} else {
replace_multi_regex_fn<RX_STACK_ANY> fn{*d_strings, d_progs, d_found_ranges, *d_repls};
return make_strings_children(fn, strings_count, stream, mr);
}
}();

return make_strings_column(strings_count,
Expand All @@ -215,11 +198,12 @@ std::unique_ptr<column> replace_re(

std::unique_ptr<column> replace_re(strings_column_view const& strings,
std::vector<std::string> const& patterns,
strings_column_view const& repls,
strings_column_view const& replacements,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_re(strings, patterns, repls, rmm::cuda_stream_default, mr);
return detail::replace_re(strings, patterns, replacements, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
68 changes: 31 additions & 37 deletions cpp/src/strings/replace/replace_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct replace_regex_fn {
column_device_view const d_strings;
reprog_device prog;
string_view const d_repl;
size_type maxrepl;
size_type const maxrepl;
int32_t* d_offsets{};
char* d_chars{};

Expand Down Expand Up @@ -102,56 +102,48 @@ struct replace_regex_fn {
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
std::string const& pattern,
string_scalar const& repl = string_scalar(""),
size_type maxrepl = -1,
string_scalar const& replacement,
std::optional<size_type> max_replace_count,
regex_flags const flags,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
if (strings_count == 0) return make_empty_column(type_id::STRING);

CUDF_EXPECTS(repl.is_valid(stream), "Parameter repl must be valid");
string_view d_repl(repl.data(), repl.size());
CUDF_EXPECTS(replacement.is_valid(stream), "Parameter replacement must be valid");
string_view d_repl(replacement.data(), replacement.size());

auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_strings = *strings_column;
// compile regex into device object
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;
auto regex_insts = d_prog.insts_counts();
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;
auto const regex_insts = d_prog.insts_counts();

// copy null mask
auto null_mask = cudf::detail::copy_bitmask(strings.parent(), stream, mr);
auto null_count = strings.null_count();
auto null_mask = cudf::detail::copy_bitmask(strings.parent(), stream, mr);
auto const null_count = strings.null_count();
auto const maxrepl = max_replace_count.value_or(-1);

// create child columns
auto children = [&] {
// Each invocation is predicated on the stack size which is dependent on the number of regex
// instructions
if (regex_insts <= RX_SMALL_INSTS)
return make_strings_children(
replace_regex_fn<RX_STACK_SMALL>{d_strings, d_prog, d_repl, maxrepl},
strings_count,
stream,
mr);
else if (regex_insts <= RX_MEDIUM_INSTS)
return make_strings_children(
replace_regex_fn<RX_STACK_MEDIUM>{d_strings, d_prog, d_repl, maxrepl},
strings_count,
stream,
mr);
else if (regex_insts <= RX_LARGE_INSTS)
return make_strings_children(
replace_regex_fn<RX_STACK_LARGE>{d_strings, d_prog, d_repl, maxrepl},
strings_count,
stream,
mr);
else
return make_strings_children(
replace_regex_fn<RX_STACK_ANY>{d_strings, d_prog, d_repl, maxrepl},
strings_count,
stream,
mr);
if (regex_insts <= RX_SMALL_INSTS) {
replace_regex_fn<RX_STACK_SMALL> fn{d_strings, d_prog, d_repl, maxrepl};
return make_strings_children(fn, strings_count, stream, mr);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
replace_regex_fn<RX_STACK_MEDIUM> fn{d_strings, d_prog, d_repl, maxrepl};
return make_strings_children(fn, strings_count, stream, mr);
} else if (regex_insts <= RX_LARGE_INSTS) {
replace_regex_fn<RX_STACK_LARGE> fn{d_strings, d_prog, d_repl, maxrepl};
return make_strings_children(fn, strings_count, stream, mr);
} else {
replace_regex_fn<RX_STACK_ANY> fn{d_strings, d_prog, d_repl, maxrepl};
return make_strings_children(fn, strings_count, stream, mr);
}
}();

return make_strings_column(strings_count,
Expand All @@ -167,12 +159,14 @@ std::unique_ptr<column> replace_re(

std::unique_ptr<column> replace_re(strings_column_view const& strings,
std::string const& pattern,
string_scalar const& repl,
size_type maxrepl,
string_scalar const& replacement,
std::optional<size_type> max_replace_count,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_re(strings, pattern, repl, maxrepl, rmm::cuda_stream_default, mr);
return detail::replace_re(
strings, pattern, replacement, max_replace_count, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
Loading

0 comments on commit db9aef8

Please sign in to comment.