Skip to content

Commit

Permalink
Regex cleanup internal reclass and reclass_device classes (#11045)
Browse files Browse the repository at this point in the history
This cleans up the awkward range literals used to support the `CCLASS` and `NCCLASS` regex instructions. The range values were always paired (first, last) but were stored consecutively in a flat vector, so `[idx]` and `[idx+1]` formed a range pair whenever `idx` was even. This PR introduces a `reclass_range` class that holds each pair so that normal algorithms can be used to manipulate them.

There is some overlap with the code changes in PR #10975.

Reference #3582

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Mike Wilson (https://github.com/hyperbolic2346)
  - MithunR (https://github.com/mythrocks)

URL: #11045
  • Loading branch information
davidwendt authored Jun 27, 2022
1 parent 88f047e commit e0003a0
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 101 deletions.
150 changes: 69 additions & 81 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <cudf/strings/detail/utf8.hpp>
#include <cudf/utilities/error.hpp>

#include <thrust/iterator/counting_iterator.h>

#include <algorithm>
#include <array>
#include <cctype>
Expand Down Expand Up @@ -50,12 +52,12 @@ enum OperatorType {
};
#define ITEM_MASK 0300

static reclass ccls_w(CCLASS_W); // \w
static reclass ccls_s(CCLASS_S); // \s
static reclass ccls_d(CCLASS_D); // \d
static reclass ccls_W(NCCLASS_W); // \W
static reclass ccls_S(NCCLASS_S); // \S
static reclass ccls_D(NCCLASS_D); // \D
static reclass cclass_w(CCLASS_W); // \w
static reclass cclass_s(CCLASS_S); // \s
static reclass cclass_d(CCLASS_D); // \d
static reclass cclass_W(NCCLASS_W); // \W
static reclass cclass_S(NCCLASS_S); // \S
static reclass cclass_D(NCCLASS_D); // \D

// Tables for analyzing quantifiers
const std::array<int, 6> valid_preceding_inst_types{{CHAR, CCLASS, NCCLASS, ANY, ANYNL, RBRA}};
Expand Down Expand Up @@ -107,16 +109,16 @@ int32_t reprog::add_inst(int32_t t)
return add_inst(inst);
}

int32_t reprog::add_inst(reinst inst)
int32_t reprog::add_inst(reinst const& inst)
{
_insts.push_back(inst);
return static_cast<int>(_insts.size() - 1);
return static_cast<int32_t>(_insts.size() - 1);
}

int32_t reprog::add_class(reclass cls)
int32_t reprog::add_class(reclass const& cls)
{
_classes.push_back(cls);
return static_cast<int>(_classes.size() - 1);
return static_cast<int32_t>(_classes.size() - 1);
}

reinst& reprog::inst_at(int32_t id) { return _insts[id]; }
Expand All @@ -135,9 +137,11 @@ void reprog::set_groups_count(int32_t groups) { _num_capturing_groups = groups;

int32_t reprog::groups_count() const { return _num_capturing_groups; }

const reinst* reprog::insts_data() const { return _insts.data(); }
reinst const* reprog::insts_data() const { return _insts.data(); }

reclass const* reprog::classes_data() const { return _classes.data(); }

const int32_t* reprog::starts_data() const { return _startinst_ids.data(); }
int32_t const* reprog::starts_data() const { return _startinst_ids.data(); }

int32_t reprog::starts_count() const { return static_cast<int>(_startinst_ids.size()); }

Expand Down Expand Up @@ -209,27 +213,24 @@ class regex_parser {
int32_t build_cclass()
{
int32_t type = CCLASS;
std::vector<char32_t> cls;
std::vector<char32_t> literals;
int32_t builtins = 0;

auto [is_quoted, chr] = next_char();
// check for negation
if (!is_quoted && chr == '^') {
type = NCCLASS;
std::tie(is_quoted, chr) = next_char();
// negated classes also do not match '\n'
cls.push_back('\n');
cls.push_back('\n');
// negated classes also don't match '\n'
literals.push_back('\n');
literals.push_back('\n');
}

// parse class into a set of spans
auto count_char = 0;
while (true) {
count_char++;
if (chr == 0) {
// malformed '[]'
return 0;
}
if (chr == 0) { return 0; } // malformed '[]'
if (is_quoted) {
switch (chr) {
case 'n': chr = '\n'; break;
Expand All @@ -239,82 +240,70 @@ class regex_parser {
case 'b': chr = 0x08; break;
case 'f': chr = 0x0C; break;
case 'w':
builtins |= ccls_w.builtins;
builtins |= cclass_w.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 's':
builtins |= ccls_s.builtins;
builtins |= cclass_s.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'd':
builtins |= ccls_d.builtins;
builtins |= cclass_d.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'W':
builtins |= ccls_W.builtins;
builtins |= cclass_W.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'S':
builtins |= ccls_S.builtins;
builtins |= cclass_S.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
case 'D':
builtins |= ccls_D.builtins;
builtins |= cclass_D.builtins;
std::tie(is_quoted, chr) = next_char();
continue;
}
}
if (!is_quoted && chr == ']' && count_char > 1) break;
if (!is_quoted && chr == ']' && count_char > 1) { break; } // done
if (!is_quoted && chr == '-') {
if (cls.empty()) {
// malformed '[]': TODO assert or exception?
return 0;
}
if (literals.empty()) { return 0; } // malformed '[]'
std::tie(is_quoted, chr) = next_char();
if ((!is_quoted && chr == ']') || chr == 0) {
// malformed '[]': TODO assert or exception?
return 0;
}
cls.back() = chr;
if ((!is_quoted && chr == ']') || chr == 0) { return 0; } // malformed '[]'
literals.back() = chr;
} else {
cls.push_back(chr);
cls.push_back(chr);
literals.push_back(chr);
literals.push_back(chr);
}
std::tie(is_quoted, chr) = next_char();
}

/* sort on span start */
for (std::size_t p = 0; p < cls.size(); p += 2)
for (std::size_t np = p + 2; np < cls.size(); np += 2)
if (cls[np] < cls[p]) {
auto c = cls[np];
cls[np] = cls[p];
cls[p] = c;
c = cls[np + 1];
cls[np + 1] = cls[p + 1];
cls[p + 1] = c;
}

/* merge spans */
reclass yycls{builtins};
if (cls.size() >= 2) {
int np = 0;
std::size_t p = 0;
yycls.literals += cls[p++];
yycls.literals += cls[p++];
for (; p < cls.size(); p += 2) {
/* overlapping or adjacent ranges? */
if (cls[p] <= yycls.literals[np + 1] + 1) {
if (cls[p + 1] >= yycls.literals[np + 1])
yycls.literals.replace(np + 1, 1, 1, cls[p + 1]); /* coalesce */
} else {
np += 2;
yycls.literals += cls[p];
yycls.literals += cls[p + 1];
// transform pairs of literals to ranges
std::vector<reclass_range> ranges(literals.size() / 2);
auto const counter = thrust::make_counting_iterator(0);
std::transform(counter, counter + ranges.size(), ranges.begin(), [&literals](auto idx) {
return reclass_range{literals[idx * 2], literals[idx * 2 + 1]};
});
// sort the ranges to help with detecting overlapping entries
std::sort(ranges.begin(), ranges.end(), [](auto l, auto r) {
return l.first == r.first ? l.last < r.last : l.first < r.first;
});
// combine overlapping entries: [a-f][c-g] => [a-g]
if (ranges.size() > 1) {
for (auto itr = ranges.begin() + 1; itr < ranges.end(); ++itr) {
auto const prev = *(itr - 1);
if (itr->first <= prev.last + 1) {
// if these 2 ranges intersect, expand the current one
*itr = reclass_range{prev.first, std::max(prev.last, itr->last)};
}
}
}
_cclass_id = _prog.add_class(yycls);
// remove any duplicates
auto const end = std::unique(
ranges.rbegin(), ranges.rend(), [](auto l, auto r) { return l.first == r.first; });
ranges.erase(ranges.begin(), ranges.begin() + std::distance(end, ranges.rend()));

_cclass_id = _prog.add_class(reclass{builtins, std::move(ranges)});
return type;
}

Expand Down Expand Up @@ -363,40 +352,38 @@ class regex_parser {
break;
}
case 'w': {
if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(ccls_w); }
if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); }
_cclass_id = _id_cclass_w;
return CCLASS;
}
case 'W': {
if (_id_cclass_W < 0) {
reclass cls = ccls_w;
cls.literals += '\n';
cls.literals += '\n';
reclass cls = cclass_w;
cls.literals.push_back({'\n', '\n'});
_id_cclass_W = _prog.add_class(cls);
}
_cclass_id = _id_cclass_W;
return NCCLASS;
}
case 's': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(ccls_s); }
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return CCLASS;
}
case 'S': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(ccls_s); }
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return NCCLASS;
}
case 'd': {
if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(ccls_d); }
if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); }
_cclass_id = _id_cclass_d;
return CCLASS;
}
case 'D': {
if (_id_cclass_D < 0) {
reclass cls = ccls_d;
cls.literals += '\n';
cls.literals += '\n';
reclass cls = cclass_d;
cls.literals.push_back({'\n', '\n'});
_id_cclass_D = _prog.add_class(cls);
}
_cclass_id = _id_cclass_D;
Expand Down Expand Up @@ -1100,15 +1087,16 @@ void reprog::print(regex_flags const flags)
const reclass& cls = _classes[i];
auto const size = static_cast<int>(cls.literals.size());
printf("%2d: ", i);
for (int j = 0; j < size; j += 2) {
char32_t c1 = cls.literals[j];
char32_t c2 = cls.literals[j + 1];
for (int j = 0; j < size; ++j) {
auto const l = cls.literals[j];
char32_t c1 = l.first;
char32_t c2 = l.last;
if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) {
printf("0x%02x-0x%02x", static_cast<unsigned>(c1), static_cast<unsigned>(c2));
} else {
printf("%c-%c", static_cast<char>(c1), static_cast<char>(c2));
}
if ((j + 2) < size) { printf(", "); }
if ((j + 1) < size) { printf(", "); }
}
printf("\n");
if (cls.builtins) {
Expand Down
24 changes: 17 additions & 7 deletions cpp/src/strings/regex/regcomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,23 @@ enum InstType {
END = 0377 // Terminate: match found
};

/**
 * @brief A (first, last) pair of characters describing one literal range
 * of a reclass.
 *
 * Both ends are part of the range: a character `ch` falls inside it when
 * `first <= ch && ch <= last`. A single literal character is stored with
 * `first == last`.
 */
struct reclass_range {
  char32_t first = 0;  ///< first character in span
  char32_t last  = 0;  ///< last character in span (inclusive)
};

/**
* @brief Class type for regex compiler instruction.
*/
struct reclass {
int32_t builtins{0}; // bit mask identifying builtin classes
std::u32string literals; // ranges as pairs of utf-8 characters
int32_t builtins{0}; // bit mask identifying builtin classes
std::vector<reclass_range> literals;
reclass() {}
reclass(int m) : builtins(m) {}
reclass(int m, std::vector<reclass_range>&& l) : builtins(m), literals(std::move(l)) {}
};

constexpr int32_t CCLASS_W{1 << 0}; // [a-z], [A-Z], [0-9], and '_'
Expand Down Expand Up @@ -105,18 +114,19 @@ class reprog {
static reprog create_from(std::string_view pattern, regex_flags const flags);

int32_t add_inst(int32_t type);
int32_t add_inst(reinst inst);
int32_t add_class(reclass cls);
int32_t add_inst(reinst const& inst);
int32_t add_class(reclass const& cls);

void set_groups_count(int32_t groups);
[[nodiscard]] int32_t groups_count() const;

[[nodiscard]] const reinst* insts_data() const;
[[nodiscard]] int32_t insts_count() const;
reinst& inst_at(int32_t id);
[[nodiscard]] reinst& inst_at(int32_t id);
[[nodiscard]] reinst const* insts_data() const;

reclass& class_at(int32_t id);
[[nodiscard]] int32_t classes_count() const;
[[nodiscard]] reclass& class_at(int32_t id);
[[nodiscard]] reclass const* classes_data() const;

[[nodiscard]] const int32_t* starts_data() const;
[[nodiscard]] int32_t starts_count() const;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ constexpr int32_t MINIMUM_THREADS = 256; // Minimum threads for computing w
struct alignas(16) reclass_device {
int32_t builtins{};
int32_t count{};
char32_t const* literals{};
reclass_range const* literals{};

__device__ inline bool is_match(char32_t const ch, uint8_t const* flags) const;
};
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ __device__ __forceinline__ bool reclass_device::is_match(char32_t const ch,
uint8_t const* codepoint_flags) const
{
for (int i = 0; i < count; ++i) {
if ((ch >= literals[i * 2]) && (ch <= literals[(i * 2) + 1])) { return true; }
auto const literal = literals[i];
if ((ch >= literal.first) && (ch <= literal.last)) { return true; }
}

if (!builtins) return false;
Expand Down Expand Up @@ -204,11 +205,11 @@ __device__ __forceinline__ void reprog_device::store(void* buffer) const
auto classes = reinterpret_cast<reclass_device*>(ptr);
result->_classes = classes;
// fill in each class
auto d_ptr = reinterpret_cast<char32_t*>(classes + _classes_count);
auto d_ptr = reinterpret_cast<reclass_range*>(classes + _classes_count);
for (int idx = 0; idx < _classes_count; ++idx) {
classes[idx] = _classes[idx];
classes[idx].literals = d_ptr;
for (int jdx = 0; jdx < _classes[idx].count * 2; ++jdx)
for (int jdx = 0; jdx < _classes[idx].count; ++jdx)
*d_ptr++ = _classes[idx].literals[jdx];
}
}
Expand Down
Loading

0 comments on commit e0003a0

Please sign in to comment.