Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regex cleanup internal reclass and reclass_device classes #11045

Merged
merged 16 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 69 additions & 82 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <cudf/strings/detail/utf8.hpp>
#include <cudf/utilities/error.hpp>

#include <thrust/iterator/counting_iterator.h>

#include <algorithm>
#include <array>
#include <cctype>
Expand Down Expand Up @@ -50,12 +52,12 @@ enum OperatorType {
};
#define ITEM_MASK 0300

static reclass ccls_w(CCLASS_W); // \w
static reclass ccls_s(CCLASS_S); // \s
static reclass ccls_d(CCLASS_D); // \d
static reclass ccls_W(NCCLASS_W); // \W
static reclass ccls_S(NCCLASS_S); // \S
static reclass ccls_D(NCCLASS_D); // \D
static reclass cclass_w(CCLASS_W); // \w
static reclass cclass_s(CCLASS_S); // \s
static reclass cclass_d(CCLASS_D); // \d
static reclass cclass_W(NCCLASS_W); // \W
static reclass cclass_S(NCCLASS_S); // \S
static reclass cclass_D(NCCLASS_D); // \D

// Tables for analyzing quantifiers
const std::array<int, 6> valid_preceding_inst_types{{CHAR, CCLASS, NCCLASS, ANY, ANYNL, RBRA}};
Expand Down Expand Up @@ -107,16 +109,16 @@ int32_t reprog::add_inst(int32_t t)
return add_inst(inst);
}

int32_t reprog::add_inst(reinst inst)
int32_t reprog::add_inst(reinst const& inst)
{
_insts.push_back(inst);
return static_cast<int>(_insts.size() - 1);
return static_cast<int32_t>(_insts.size() - 1);
}

int32_t reprog::add_class(reclass cls)
int32_t reprog::add_class(reclass const& cls)
{
_classes.push_back(cls);
return static_cast<int>(_classes.size() - 1);
return static_cast<int32_t>(_classes.size() - 1);
}

// Returns a mutable reference to the instruction stored at index `id`.
// The index is used as-is; no range check is performed here.
reinst& reprog::inst_at(int32_t id)
{
  return _insts[id];
}
Expand All @@ -135,9 +137,11 @@ void reprog::set_groups_count(int32_t groups) { _num_capturing_groups = groups;

// Returns the stored capturing-group count (`_num_capturing_groups`).
int32_t reprog::groups_count() const
{
  return _num_capturing_groups;
}

const reinst* reprog::insts_data() const { return _insts.data(); }
reinst const* reprog::insts_data() const { return _insts.data(); }

// Returns a read-only pointer to the first element of the class list
// (`_classes`), as reported by its `data()` member.
reclass const* reprog::classes_data() const
{
  return _classes.data();
}

const int32_t* reprog::starts_data() const { return _startinst_ids.data(); }
int32_t const* reprog::starts_data() const { return _startinst_ids.data(); }

// Returns the number of entries in the start-instruction id list
// (`_startinst_ids`). The size is narrowed to the declared return type,
// using the same `int32_t` cast as add_inst()/add_class().
int32_t reprog::starts_count() const
{
  return static_cast<int32_t>(_startinst_ids.size());
}

Expand Down Expand Up @@ -209,27 +213,24 @@ class regex_parser {
int32_t build_cclass()
{
int32_t type = CCLASS;
std::vector<char32_t> cls;
std::vector<char32_t> literals;
int32_t builtins = 0;

auto [is_quoted, chr] = next_char();
// check for negation
if (!is_quoted && chr == '^') {
type = NCCLASS;
std::tie(is_quoted, chr) = next_char();
// negated classes also do not match '\n'
cls.push_back('\n');
cls.push_back('\n');
// negated classes also don't match '\n'
literals.push_back('\n');
literals.push_back('\n');
}

// parse class into a set of spans
auto count_char = 0;
while (true) {
count_char++;
if (chr == 0) {
// malformed '[]'
return 0;
}
if (chr == 0) { return 0; } // malformed '[]'
if (is_quoted) {
switch (chr) {
case 'n': chr = '\n'; break;
Expand All @@ -239,82 +240,69 @@ class regex_parser {
case 'b': chr = 0x08; break;
case 'f': chr = 0x0C; break;
case 'w':
builtins |= ccls_w.builtins;
builtins |= cclass_w.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 's':
builtins |= ccls_s.builtins;
builtins |= cclass_s.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'd':
builtins |= ccls_d.builtins;
builtins |= cclass_d.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'W':
builtins |= ccls_W.builtins;
builtins |= cclass_W.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'S':
builtins |= ccls_S.builtins;
builtins |= cclass_S.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'D':
builtins |= ccls_D.builtins;
builtins |= cclass_D.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
}
}
if (!is_quoted && chr == ']' && count_char > 1) break;
if (!is_quoted && chr == ']' && count_char > 1) { break; } // done
if (!is_quoted && chr == '-') {
if (cls.empty()) {
// malformed '[]': TODO assert or exception?
return 0;
}
if (literals.empty()) { return 0; } // malformed '[]'
std::tie(is_quoted, chr) = next_char();
if ((!is_quoted && chr == ']') || chr == 0) {
// malformed '[]': TODO assert or exception?
return 0;
}
cls.back() = chr;
if ((!is_quoted && chr == ']') || chr == 0) { return 0; } // malformed '[]'
literals.back() = chr;
} else {
cls.push_back(chr);
cls.push_back(chr);
literals.push_back(chr);
literals.push_back(chr);
}
std::tie(is_quoted, chr) = next_char();
}

/* sort on span start */
for (std::size_t p = 0; p < cls.size(); p += 2)
for (std::size_t np = p + 2; np < cls.size(); np += 2)
if (cls[np] < cls[p]) {
auto c = cls[np];
cls[np] = cls[p];
cls[p] = c;
c = cls[np + 1];
cls[np + 1] = cls[p + 1];
cls[p + 1] = c;
}

/* merge spans */
reclass yycls{builtins};
if (cls.size() >= 2) {
int np = 0;
std::size_t p = 0;
yycls.literals += cls[p++];
yycls.literals += cls[p++];
for (; p < cls.size(); p += 2) {
/* overlapping or adjacent ranges? */
if (cls[p] <= yycls.literals[np + 1] + 1) {
if (cls[p + 1] >= yycls.literals[np + 1])
yycls.literals.replace(np + 1, 1, 1, cls[p + 1]); /* coalesce */
} else {
np += 2;
yycls.literals += cls[p];
yycls.literals += cls[p + 1];
}
// transform pairs of literals to ranges
std::vector<reclass_range> ranges(literals.size() / 2);
auto const counter = thrust::make_counting_iterator(0);
std::transform(counter, counter + ranges.size(), ranges.begin(), [&literals](auto idx) {
return reclass_range{literals[idx * 2], literals[idx * 2 + 1]};
});
// sort the ranges to help with detecting overlapping entries
std::sort(ranges.begin(), ranges.end(), [](auto l, auto r) {
return l.first == r.first ? l.last < r.last : l.first < r.first;
});
// combine overlapping entries: [a-f][c-g] => [a-g]
for (auto itr = ranges.begin() + static_cast<int>(!ranges.empty()); itr < ranges.end(); ++itr) {
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved
auto const prev = *(itr - 1);
if (itr->first <= prev.last + 1) {
// if these 2 ranges intersect, expand the current one
*itr = reclass_range{prev.first, std::max(prev.last, itr->last)};
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved
}
}
_cclass_id = _prog.add_class(yycls);
// remove any duplicates
std::reverse(ranges.begin(), ranges.end()); // moves larger overlaps forward
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved
auto const end = // std::unique specifies keeping the first entry in a repeated sequence
std::unique(ranges.begin(), ranges.end(), [](auto l, auto r) { return l.first == r.first; });
ranges.erase(end, ranges.end()); // clear the remaining items

_cclass_id = _prog.add_class(reclass{builtins, std::move(ranges)});
return type;
}

Expand Down Expand Up @@ -363,40 +351,38 @@ class regex_parser {
break;
}
case 'w': {
if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(ccls_w); }
if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); }
_cclass_id = _id_cclass_w;
return CCLASS;
}
case 'W': {
if (_id_cclass_W < 0) {
reclass cls = ccls_w;
cls.literals += '\n';
cls.literals += '\n';
reclass cls = cclass_w;
cls.literals.push_back({'\n', '\n'});
_id_cclass_W = _prog.add_class(cls);
}
_cclass_id = _id_cclass_W;
return NCCLASS;
}
case 's': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(ccls_s); }
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return CCLASS;
}
case 'S': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(ccls_s); }
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return NCCLASS;
}
case 'd': {
if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(ccls_d); }
if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); }
_cclass_id = _id_cclass_d;
return CCLASS;
}
case 'D': {
if (_id_cclass_D < 0) {
reclass cls = ccls_d;
cls.literals += '\n';
cls.literals += '\n';
reclass cls = cclass_d;
cls.literals.push_back({'\n', '\n'});
_id_cclass_D = _prog.add_class(cls);
}
_cclass_id = _id_cclass_D;
Expand Down Expand Up @@ -1100,15 +1086,16 @@ void reprog::print(regex_flags const flags)
const reclass& cls = _classes[i];
auto const size = static_cast<int>(cls.literals.size());
printf("%2d: ", i);
for (int j = 0; j < size; j += 2) {
char32_t c1 = cls.literals[j];
char32_t c2 = cls.literals[j + 1];
for (int j = 0; j < size; ++j) {
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved
auto const l = cls.literals[j];
char32_t c1 = l.first;
char32_t c2 = l.last;
if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) {
printf("0x%02x-0x%02x", static_cast<unsigned>(c1), static_cast<unsigned>(c2));
} else {
printf("%c-%c", static_cast<char>(c1), static_cast<char>(c2));
}
if ((j + 2) < size) { printf(", "); }
if ((j + 1) < size) { printf(", "); }
}
printf("\n");
if (cls.builtins) {
Expand Down
24 changes: 17 additions & 7 deletions cpp/src/strings/regex/regcomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,23 @@ enum InstType {
END = 0377 // Terminate: match found
};

/**
 * @brief Inclusive character range used for the literals of a reclass.
 *
 * Both members are zero when the range is default-constructed.
 */
struct reclass_range {
  char32_t first = 0;  ///< first character in span
  char32_t last  = 0;  ///< last character in span (inclusive)
};

/**
* @brief Class type for regex compiler instruction.
*/
struct reclass {
int32_t builtins{0}; // bit mask identifying builtin classes
std::u32string literals; // ranges as pairs of utf-8 characters
int32_t builtins{0}; // bit mask identifying builtin classes
std::vector<reclass_range> literals;
reclass() {}
reclass(int m) : builtins(m) {}
reclass(int m, std::vector<reclass_range>&& l) : builtins(m), literals(std::move(l)) {}
};

constexpr int32_t CCLASS_W{1 << 0}; // [a-z], [A-Z], [0-9], and '_'
Expand Down Expand Up @@ -105,18 +114,19 @@ class reprog {
static reprog create_from(std::string_view pattern, regex_flags const flags);

int32_t add_inst(int32_t type);
int32_t add_inst(reinst inst);
int32_t add_class(reclass cls);
int32_t add_inst(reinst const& inst);
int32_t add_class(reclass const& cls);

void set_groups_count(int32_t groups);
[[nodiscard]] int32_t groups_count() const;

[[nodiscard]] const reinst* insts_data() const;
[[nodiscard]] int32_t insts_count() const;
reinst& inst_at(int32_t id);
[[nodiscard]] reinst& inst_at(int32_t id);
[[nodiscard]] reinst const* insts_data() const;

reclass& class_at(int32_t id);
[[nodiscard]] int32_t classes_count() const;
[[nodiscard]] reclass& class_at(int32_t id);
[[nodiscard]] reclass const* classes_data() const;

[[nodiscard]] const int32_t* starts_data() const;
[[nodiscard]] int32_t starts_count() const;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ constexpr int32_t MINIMUM_THREADS = 256; // Minimum threads for computing w
struct alignas(16) reclass_device {
int32_t builtins{};
int32_t count{};
char32_t const* literals{};
reclass_range const* literals{};

__device__ inline bool is_match(char32_t const ch, uint8_t const* flags) const;
};
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ __device__ __forceinline__ bool reclass_device::is_match(char32_t const ch,
uint8_t const* codepoint_flags) const
{
for (int i = 0; i < count; ++i) {
if ((ch >= literals[i * 2]) && (ch <= literals[(i * 2) + 1])) { return true; }
auto const literal = literals[i];
if ((ch >= literal.first) && (ch <= literal.last)) { return true; }
}

if (!builtins) return false;
Expand Down Expand Up @@ -204,11 +205,11 @@ __device__ __forceinline__ void reprog_device::store(void* buffer) const
auto classes = reinterpret_cast<reclass_device*>(ptr);
result->_classes = classes;
// fill in each class
auto d_ptr = reinterpret_cast<char32_t*>(classes + _classes_count);
auto d_ptr = reinterpret_cast<reclass_range*>(classes + _classes_count);
for (int idx = 0; idx < _classes_count; ++idx) {
classes[idx] = _classes[idx];
classes[idx].literals = d_ptr;
for (int jdx = 0; jdx < _classes[idx].count * 2; ++jdx)
for (int jdx = 0; jdx < _classes[idx].count; ++jdx)
*d_ptr++ = _classes[idx].literals[jdx];
}
}
Expand Down
Loading