Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-22.06' into revise-10min
Browse files (browse the repository at this point in the history)
  • Loading branch information
bdice committed May 13, 2022
2 parents 3297f3c + 4ad1e51 commit 49f60ae
Show file tree
Hide file tree
Showing 22 changed files with 156 additions and 126 deletions.
8 changes: 4 additions & 4 deletions cpp/include/cudf/strings/contains.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,7 +51,7 @@ namespace strings {
*/
std::unique_ptr<column> contains_re(
strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand All @@ -78,7 +78,7 @@ std::unique_ptr<column> contains_re(
*/
std::unique_ptr<column> matches_re(
strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand All @@ -105,7 +105,7 @@ std::unique_ptr<column> matches_re(
*/
std::unique_ptr<column> count_re(
strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/strings/extract.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace strings {
*/
std::unique_ptr<table> extract(
strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -90,7 +90,7 @@ std::unique_ptr<table> extract(
*/
std::unique_ptr<column> extract_all_record(
strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/strings/findall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace strings {
*/
std::unique_ptr<table> findall(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -90,7 +90,7 @@ std::unique_ptr<table> findall(
*/
std::unique_ptr<column> findall_record(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
6 changes: 3 additions & 3 deletions cpp/include/cudf/strings/replace_re.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ namespace strings {
*/
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
string_scalar const& replacement = string_scalar(""),
std::optional<size_type> max_replace_count = std::nullopt,
regex_flags const flags = regex_flags::DEFAULT,
Expand Down Expand Up @@ -98,8 +98,8 @@ std::unique_ptr<column> replace_re(
*/
std::unique_ptr<column> replace_with_backrefs(
strings_column_view const& strings,
std::string const& pattern,
std::string const& replacement,
std::string_view pattern,
std::string_view replacement,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
8 changes: 4 additions & 4 deletions cpp/include/cudf/strings/split/split_re.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace strings {
*/
std::unique_ptr<table> split_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
size_type maxsplit = -1,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -121,7 +121,7 @@ std::unique_ptr<table> split_re(
*/
std::unique_ptr<table> rsplit_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
size_type maxsplit = -1,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -173,7 +173,7 @@ std::unique_ptr<table> rsplit_re(
*/
std::unique_ptr<column> split_record_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
size_type maxsplit = -1,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -227,7 +227,7 @@ std::unique_ptr<column> split_record_re(
*/
std::unique_ptr<column> rsplit_record_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
size_type maxsplit = -1,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
14 changes: 7 additions & 7 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct contains_fn {
};

std::unique_ptr<column> contains_impl(strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
bool const beginning_only,
rmm::cuda_stream_view stream,
Expand Down Expand Up @@ -85,7 +85,7 @@ std::unique_ptr<column> contains_impl(strings_column_view const& input,

std::unique_ptr<column> contains_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
Expand All @@ -95,7 +95,7 @@ std::unique_ptr<column> contains_re(

std::unique_ptr<column> matches_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
Expand All @@ -105,7 +105,7 @@ std::unique_ptr<column> matches_re(

std::unique_ptr<column> count_re(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
Expand All @@ -128,7 +128,7 @@ std::unique_ptr<column> count_re(
// external APIs

std::unique_ptr<column> contains_re(strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
Expand All @@ -137,7 +137,7 @@ std::unique_ptr<column> contains_re(strings_column_view const& strings,
}

std::unique_ptr<column> matches_re(strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
Expand All @@ -146,7 +146,7 @@ std::unique_ptr<column> matches_re(strings_column_view const& strings,
}

std::unique_ptr<column> count_re(strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/strings/extract/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct extract_fn {

//
std::unique_ptr<table> extract(strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand Down Expand Up @@ -130,7 +130,7 @@ std::unique_ptr<table> extract(strings_column_view const& input,
// external API

std::unique_ptr<table> extract(strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/strings/extract/extract_all.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ struct extract_fn {
*/
std::unique_ptr<column> extract_all_record(
strings_column_view const& input,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
Expand Down Expand Up @@ -165,7 +165,7 @@ std::unique_ptr<column> extract_all_record(
// external API

std::unique_ptr<column> extract_all_record(strings_column_view const& strings,
std::string const& pattern,
std::string_view pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
Expand Down
124 changes: 61 additions & 63 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

#include <algorithm>
#include <array>
#include <cstring>
#include <numeric>
#include <stack>
#include <string>

namespace cudf {
namespace strings {
Expand Down Expand Up @@ -862,8 +864,7 @@ class regex_compiler {
; // "unmatched left paren";
/* points to first and only operand */
m_prog.set_start_inst(andstack[andstack.size() - 1].id_first);
m_prog.optimize1();
m_prog.optimize2();
m_prog.finalize();
m_prog.check_for_errors();
m_prog.set_groups_count(cursubid);
}
Expand All @@ -880,81 +881,78 @@ reprog reprog::create_from(std::string_view pattern, regex_flags const flags)
return rtn;
}

// NOTE(review): this span is a rendered diff hunk, not compilable source. It
// looks like the body of the removed reprog::optimize1() is interleaved, without
// +/- markers, with the added reprog::finalize() and reprog::collapse_nops() --
// confirm against the commit before lifting code from here.
//
// finalize(): runs collapse_nops() and then build_start_ids().
void reprog::optimize1()
void reprog::finalize()
{
// Treat non-capturing LBRAs/RBRAs as NOOP
for (int i = 0; i < static_cast<int>(_insts.size()); i++) {
if (_insts[i].type == LBRA || _insts[i].type == RBRA) {
if (_insts[i].u1.subid < 1) { _insts[i].type = NOP; }
collapse_nops();
build_start_ids();
}

// collapse_nops(): steps visible in the algorithm-based statements below:
//   1. retype LBRA/RBRA instructions whose u1.subid < 1 as NOP
//   2. repoint u2.next_id (and u1.right_id for OR) of every non-NOP instruction
//      past any chain of NOPs
//   3. advance _startinst_id past leading NOPs
//   4. build id_map (old index -> index after NOP removal), erase the NOPs,
//      then remap next_id / right_id / _startinst_id through id_map
void reprog::collapse_nops()
{
// treat non-capturing LBRAs/RBRAs as NOP
std::transform(_insts.begin(), _insts.end(), _insts.begin(), [](auto inst) {
if ((inst.type == LBRA || inst.type == RBRA) && (inst.u1.subid < 1)) { inst.type = NOP; }
return inst;
});

// functor for finding the next valid op
// NOTE(review): `insts = _insts` is a by-value init-capture, so the closure
// holds its own copy of the whole instruction vector taken at this point.
// The walk only reads `type` and `u2.next_id` of NOP entries, and the
// transform below leaves NOP entries untouched, so a by-reference capture
// looks equivalent and would avoid the copy -- TODO confirm.
auto find_next_op = [insts = _insts](int id) {
while (insts[id].type == NOP) {
id = insts[id].u2.next_id;
}
}
return id;
};

// get rid of NOP chains
for (int i = 0; i < insts_count(); i++) {
if (_insts[i].type != NOP) {
{
int target_id = _insts[i].u2.next_id;
while (_insts[target_id].type == NOP)
target_id = _insts[target_id].u2.next_id;
_insts[i].u2.next_id = target_id;
}
if (_insts[i].type == OR) {
int target_id = _insts[i].u1.right_id;
while (_insts[target_id].type == NOP)
target_id = _insts[target_id].u2.next_id;
_insts[i].u1.right_id = target_id;
}
// create new routes around NOP chains
// NOTE(review): `[find_next_op]` captures the closure by value, which copies
// the instruction vector it holds a second time.
std::transform(_insts.begin(), _insts.end(), _insts.begin(), [find_next_op](auto inst) {
if (inst.type != NOP) {
inst.u2.next_id = find_next_op(inst.u2.next_id);
if (inst.type == OR) { inst.u1.right_id = find_next_op(inst.u1.right_id); }
}
}
// skip NOPs from the beginning
{
int target_id = _startinst_id;
while (_insts[target_id].type == NOP)
target_id = _insts[target_id].u2.next_id;
_startinst_id = target_id;
}
// actually remove the no-ops
return inst;
});

// find starting op
_startinst_id = find_next_op(_startinst_id);

// build a map of op ids
// these are used to fix up the ids after the NOPs are removed
std::vector<int> id_map(insts_count());
int j = 0; // compact the ops (non no-ops)
for (int i = 0; i < insts_count(); i++) {
id_map[i] = j;
if (_insts[i].type != NOP) {
_insts[j] = _insts[i];
j++;
}
}
_insts.resize(j);
// fix up the ORs
for (int i = 0; i < insts_count(); i++) {
{
int target_id = _insts[i].u2.next_id;
_insts[i].u2.next_id = id_map[target_id];
}
if (_insts[i].type == OR) {
int target_id = _insts[i].u1.right_id;
_insts[i].u1.right_id = id_map[target_id];
}
}
// set the new start id
// exclusive scan over a 0/1 "is not NOP" flag with initial value 0:
// id_map[k] becomes the count of non-NOP instructions before index k,
// i.e. the index instruction k will have once the NOPs are erased
std::transform_exclusive_scan(
_insts.begin(), _insts.end(), id_map.begin(), 0, std::plus<int>{}, [](auto inst) {
return static_cast<int>(inst.type != NOP);
});

// remove the NOP instructions
auto end = std::remove_if(_insts.begin(), _insts.end(), [](auto i) { return i.type == NOP; });
_insts.resize(std::distance(_insts.begin(), end));

// fix up the ids on the remaining instructions using the id_map
// NOTE(review): `[id_map]` copies the map vector into the closure; it is only
// read here, so a by-reference capture would avoid the copy.
std::transform(_insts.begin(), _insts.end(), _insts.begin(), [id_map](auto inst) {
inst.u2.next_id = id_map[inst.u2.next_id];
if (inst.type == OR) { inst.u1.right_id = id_map[inst.u1.right_id]; }
return inst;
});

// fix up the start instruction id too
_startinst_id = id_map[_startinst_id];
}

// expand leading ORs to multiple startinst_ids
void reprog::optimize2()
void reprog::build_start_ids()
{
_startinst_ids.clear();
std::vector<int> stack;
stack.push_back(_startinst_id);
while (!stack.empty()) {
int id = stack.back();
stack.pop_back();
std::stack<int> ids;
ids.push(_startinst_id);
while (!ids.empty()) {
int id = ids.top();
ids.pop();
const reinst& inst = _insts[id];
if (inst.type == OR) {
if (inst.u2.left_id != id) // prevents infinite while-loop here
stack.push_back(inst.u2.left_id);
ids.push(inst.u2.left_id);
if (inst.u1.right_id != id) // prevents infinite while-loop here
stack.push_back(inst.u1.right_id);
ids.push(inst.u1.right_id);
} else {
_startinst_ids.push_back(id);
}
Expand Down
Loading

0 comments on commit 49f60ae

Please sign in to comment.