Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup regex compile optimize functions #10825

Merged
merged 3 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 61 additions & 63 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

#include <algorithm>
#include <array>
#include <cstring>
#include <numeric>
#include <stack>
#include <string>

namespace cudf {
namespace strings {
Expand Down Expand Up @@ -862,8 +864,7 @@ class regex_compiler {
; // "unmatched left paren";
/* points to first and only operand */
m_prog.set_start_inst(andstack[andstack.size() - 1].id_first);
m_prog.optimize1();
m_prog.optimize2();
m_prog.finalize();
m_prog.check_for_errors();
m_prog.set_groups_count(cursubid);
}
Expand All @@ -880,81 +881,78 @@ reprog reprog::create_from(std::string_view pattern, regex_flags const flags)
return rtn;
}

//
void reprog::optimize1()
void reprog::finalize()
{
// Treat non-capturing LBRAs/RBRAs as NOOP
for (int i = 0; i < static_cast<int>(_insts.size()); i++) {
if (_insts[i].type == LBRA || _insts[i].type == RBRA) {
if (_insts[i].u1.subid < 1) { _insts[i].type = NOP; }
collapse_nops();
build_start_ids();
}

void reprog::collapse_nops()
{
// treat non-capturing LBRAs/RBRAs as NOP
std::transform(_insts.begin(), _insts.end(), _insts.begin(), [](auto inst) {
if ((inst.type == LBRA || inst.type == RBRA) && (inst.u1.subid < 1)) { inst.type = NOP; }
return inst;
});

// functor for finding the next valid op
auto find_next_op = [insts = _insts](int id) {
while (insts[id].type == NOP) {
id = insts[id].u2.next_id;
}
}
return id;
};

// get rid of NOP chains
for (int i = 0; i < insts_count(); i++) {
if (_insts[i].type != NOP) {
{
int target_id = _insts[i].u2.next_id;
while (_insts[target_id].type == NOP)
target_id = _insts[target_id].u2.next_id;
_insts[i].u2.next_id = target_id;
}
if (_insts[i].type == OR) {
int target_id = _insts[i].u1.right_id;
while (_insts[target_id].type == NOP)
target_id = _insts[target_id].u2.next_id;
_insts[i].u1.right_id = target_id;
}
// create new routes around NOP chains
std::transform(_insts.begin(), _insts.end(), _insts.begin(), [find_next_op](auto inst) {
if (inst.type != NOP) {
inst.u2.next_id = find_next_op(inst.u2.next_id);
if (inst.type == OR) { inst.u1.right_id = find_next_op(inst.u1.right_id); }
}
}
// skip NOPs from the beginning
{
int target_id = _startinst_id;
while (_insts[target_id].type == NOP)
target_id = _insts[target_id].u2.next_id;
_startinst_id = target_id;
}
// actually remove the no-ops
return inst;
});

// find starting op
_startinst_id = find_next_op(_startinst_id);

// build a map of op ids
// these are used to fix up the ids after the NOPs are removed
std::vector<int> id_map(insts_count());
int j = 0; // compact the ops (non no-ops)
for (int i = 0; i < insts_count(); i++) {
id_map[i] = j;
if (_insts[i].type != NOP) {
_insts[j] = _insts[i];
j++;
}
}
_insts.resize(j);
// fix up the ORs
for (int i = 0; i < insts_count(); i++) {
{
int target_id = _insts[i].u2.next_id;
_insts[i].u2.next_id = id_map[target_id];
}
if (_insts[i].type == OR) {
int target_id = _insts[i].u1.right_id;
_insts[i].u1.right_id = id_map[target_id];
}
}
// set the new start id
std::transform_exclusive_scan(
_insts.begin(), _insts.end(), id_map.begin(), 0, std::plus<int>{}, [](auto inst) {
return static_cast<int>(inst.type != NOP);
});

// remove the NOP instructions
auto end = std::remove_if(_insts.begin(), _insts.end(), [](auto i) { return i.type == NOP; });
_insts.resize(std::distance(_insts.begin(), end));

// fix up the ids on the remaining instructions using the id_map
std::transform(_insts.begin(), _insts.end(), _insts.begin(), [id_map](auto inst) {
inst.u2.next_id = id_map[inst.u2.next_id];
if (inst.type == OR) { inst.u1.right_id = id_map[inst.u1.right_id]; }
return inst;
});

// fix up the start instruction id too
_startinst_id = id_map[_startinst_id];
}

// expand leading ORs to multiple startinst_ids
void reprog::optimize2()
void reprog::build_start_ids()
{
_startinst_ids.clear();
std::vector<int> stack;
stack.push_back(_startinst_id);
while (!stack.empty()) {
int id = stack.back();
stack.pop_back();
std::stack<int> ids;
ids.push(_startinst_id);
while (!ids.empty()) {
int id = ids.top();
ids.pop();
const reinst& inst = _insts[id];
if (inst.type == OR) {
if (inst.u2.left_id != id) // prevents infinite while-loop here
stack.push_back(inst.u2.left_id);
ids.push(inst.u2.left_id);
if (inst.u1.right_id != id) // prevents infinite while-loop here
stack.push_back(inst.u1.right_id);
ids.push(inst.u1.right_id);
} else {
_startinst_ids.push_back(id);
}
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/strings/regex/regcomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ class reprog {
void set_start_inst(int32_t id);
[[nodiscard]] int32_t get_start_inst() const;

void optimize1();
void optimize2();
void finalize();
void check_for_errors();
#ifndef NDEBUG
void print(regex_flags const flags);
Expand All @@ -139,6 +138,8 @@ class reprog {
int32_t _num_capturing_groups{};

reprog() = default;
void collapse_nops();
void build_start_ids();
void check_for_errors(int32_t id, int32_t next_id);
};

Expand Down