diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index e92da533414..992d66a5ff4 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -19,6 +19,8 @@ #include #include +#include + #include #include #include @@ -50,12 +52,12 @@ enum OperatorType { }; #define ITEM_MASK 0300 -static reclass ccls_w(CCLASS_W); // \w -static reclass ccls_s(CCLASS_S); // \s -static reclass ccls_d(CCLASS_D); // \d -static reclass ccls_W(NCCLASS_W); // \W -static reclass ccls_S(NCCLASS_S); // \S -static reclass ccls_D(NCCLASS_D); // \D +static reclass cclass_w(CCLASS_W); // \w +static reclass cclass_s(CCLASS_S); // \s +static reclass cclass_d(CCLASS_D); // \d +static reclass cclass_W(NCCLASS_W); // \W +static reclass cclass_S(NCCLASS_S); // \S +static reclass cclass_D(NCCLASS_D); // \D // Tables for analyzing quantifiers const std::array valid_preceding_inst_types{{CHAR, CCLASS, NCCLASS, ANY, ANYNL, RBRA}}; @@ -107,16 +109,16 @@ int32_t reprog::add_inst(int32_t t) return add_inst(inst); } -int32_t reprog::add_inst(reinst inst) +int32_t reprog::add_inst(reinst const& inst) { _insts.push_back(inst); - return static_cast(_insts.size() - 1); + return static_cast(_insts.size() - 1); } -int32_t reprog::add_class(reclass cls) +int32_t reprog::add_class(reclass const& cls) { _classes.push_back(cls); - return static_cast(_classes.size() - 1); + return static_cast(_classes.size() - 1); } reinst& reprog::inst_at(int32_t id) { return _insts[id]; } @@ -135,9 +137,11 @@ void reprog::set_groups_count(int32_t groups) { _num_capturing_groups = groups; int32_t reprog::groups_count() const { return _num_capturing_groups; } -const reinst* reprog::insts_data() const { return _insts.data(); } +reinst const* reprog::insts_data() const { return _insts.data(); } + +reclass const* reprog::classes_data() const { return _classes.data(); } -const int32_t* reprog::starts_data() const { return _startinst_ids.data(); } +int32_t const* reprog::starts_data() const { return _startinst_ids.data(); } int32_t reprog::starts_count() const { return static_cast(_startinst_ids.size()); } @@ -209,7 +213,7 @@ class regex_parser { int32_t build_cclass() { int32_t type = CCLASS; - std::vector cls; + std::vector literals; int32_t builtins = 0; auto [is_quoted, chr] = next_char(); @@ -217,19 +221,16 @@ class regex_parser { if (!is_quoted && chr == '^') { type = NCCLASS; std::tie(is_quoted, chr) = next_char(); - // negated classes also do not match '\n' - cls.push_back('\n'); - cls.push_back('\n'); + // negated classes also don't match '\n' + literals.push_back('\n'); + literals.push_back('\n'); } // parse class into a set of spans auto count_char = 0; while (true) { count_char++; - if (chr == 0) { - // malformed '[]' - return 0; - } + if (chr == 0) { return 0; } // malformed '[]' if (is_quoted) { switch (chr) { case 'n': chr = '\n'; break; @@ -239,82 +240,70 @@ class regex_parser { case 'b': chr = 0x08; break; case 'f': chr = 0x0C; break; case 'w': - builtins |= ccls_w.builtins; + builtins |= cclass_w.builtins; std::tie(is_quoted, chr) = next_char(); continue; case 's': - builtins |= ccls_s.builtins; + builtins |= cclass_s.builtins; std::tie(is_quoted, chr) = next_char(); continue; case 'd': - builtins |= ccls_d.builtins; + builtins |= cclass_d.builtins; std::tie(is_quoted, chr) = next_char(); continue; case 'W': - builtins |= ccls_W.builtins; + builtins |= cclass_W.builtins; std::tie(is_quoted, chr) = next_char(); continue; case 'S': - builtins |= ccls_S.builtins; + builtins |= cclass_S.builtins; std::tie(is_quoted, chr) = next_char(); continue; case 'D': - builtins |= ccls_D.builtins; + builtins |= cclass_D.builtins; std::tie(is_quoted, chr) = next_char(); continue; } } - if (!is_quoted && chr == ']' && count_char > 1) break; + if (!is_quoted && chr == ']' && count_char > 1) { break; } // done if (!is_quoted && chr == '-') { - if (cls.empty()) { - // malformed '[]': TODO assert or exception? - return 0; - } + if (literals.empty()) { return 0; } // malformed '[]' std::tie(is_quoted, chr) = next_char(); - if ((!is_quoted && chr == ']') || chr == 0) { - // malformed '[]': TODO assert or exception? - return 0; - } - cls.back() = chr; + if ((!is_quoted && chr == ']') || chr == 0) { return 0; } // malformed '[]' + literals.back() = chr; } else { - cls.push_back(chr); - cls.push_back(chr); + literals.push_back(chr); + literals.push_back(chr); } std::tie(is_quoted, chr) = next_char(); } - /* sort on span start */ - for (std::size_t p = 0; p < cls.size(); p += 2) - for (std::size_t np = p + 2; np < cls.size(); np += 2) - if (cls[np] < cls[p]) { - auto c = cls[np]; - cls[np] = cls[p]; - cls[p] = c; - c = cls[np + 1]; - cls[np + 1] = cls[p + 1]; - cls[p + 1] = c; - } - - /* merge spans */ - reclass yycls{builtins}; - if (cls.size() >= 2) { - int np = 0; - std::size_t p = 0; - yycls.literals += cls[p++]; - yycls.literals += cls[p++]; - for (; p < cls.size(); p += 2) { - /* overlapping or adjacent ranges? */ - if (cls[p] <= yycls.literals[np + 1] + 1) { - if (cls[p + 1] >= yycls.literals[np + 1]) - yycls.literals.replace(np + 1, 1, 1, cls[p + 1]); /* coalesce */ - } else { - np += 2; - yycls.literals += cls[p]; - yycls.literals += cls[p + 1]; + // transform pairs of literals to ranges + std::vector ranges(literals.size() / 2); + auto const counter = thrust::make_counting_iterator(0); + std::transform(counter, counter + ranges.size(), ranges.begin(), [&literals](auto idx) { + return reclass_range{literals[idx * 2], literals[idx * 2 + 1]}; + }); + // sort the ranges to help with detecting overlapping entries + std::sort(ranges.begin(), ranges.end(), [](auto l, auto r) { + return l.first == r.first ? l.last < r.last : l.first < r.first; + }); + // combine overlapping entries: [a-f][c-g] => [a-g] + if (ranges.size() > 1) { + for (auto itr = ranges.begin() + 1; itr < ranges.end(); ++itr) { + auto const prev = *(itr - 1); + if (itr->first <= prev.last + 1) { + // if these 2 ranges intersect, expand the current one + *itr = reclass_range{prev.first, std::max(prev.last, itr->last)}; } } } - _cclass_id = _prog.add_class(yycls); + // remove any duplicates + auto const end = std::unique( + ranges.rbegin(), ranges.rend(), [](auto l, auto r) { return l.first == r.first; }); + ranges.erase(ranges.begin(), ranges.begin() + std::distance(end, ranges.rend())); + + _cclass_id = _prog.add_class(reclass{builtins, std::move(ranges)}); return type; } @@ -363,40 +352,38 @@ class regex_parser { break; } case 'w': { - if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(ccls_w); } + if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); } _cclass_id = _id_cclass_w; return CCLASS; } case 'W': { if (_id_cclass_W < 0) { - reclass cls = ccls_w; - cls.literals += '\n'; - cls.literals += '\n'; + reclass cls = cclass_w; + cls.literals.push_back({'\n', '\n'}); _id_cclass_W = _prog.add_class(cls); } _cclass_id = _id_cclass_W; return NCCLASS; } case 's': { - if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(ccls_s); } + if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); } _cclass_id = _id_cclass_s; return CCLASS; } case 'S': { - if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(ccls_s); } + if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); } _cclass_id = _id_cclass_s; return NCCLASS; } case 'd': { - if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(ccls_d); } + if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); } _cclass_id = _id_cclass_d; return CCLASS; } case 'D': { if (_id_cclass_D < 0) { - reclass cls = ccls_d; - cls.literals += '\n'; - cls.literals += '\n'; + reclass cls = cclass_d; + cls.literals.push_back({'\n', '\n'}); _id_cclass_D = _prog.add_class(cls); } _cclass_id = _id_cclass_D; @@ -1100,15 +1087,16 @@ void reprog::print(regex_flags const flags) const reclass& cls = _classes[i]; auto const size = static_cast(cls.literals.size()); printf("%2d: ", i); - for (int j = 0; j < size; j += 2) { - char32_t c1 = cls.literals[j]; - char32_t c2 = cls.literals[j + 1]; + for (int j = 0; j < size; ++j) { + auto const l = cls.literals[j]; + char32_t c1 = l.first; + char32_t c2 = l.last; if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) { printf("0x%02x-0x%02x", static_cast(c1), static_cast(c2)); } else { printf("%c-%c", static_cast(c1), static_cast(c2)); } - if ((j + 2) < size) { printf(", "); } + if ((j + 1) < size) { printf(", "); } } printf("\n"); if (cls.builtins) { diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h index 48395e8cf1f..10092137c77 100644 --- a/cpp/src/strings/regex/regcomp.h +++ b/cpp/src/strings/regex/regcomp.h @@ -47,14 +47,23 @@ enum InstType { END = 0377 // Terminate: match found }; +/** + * @brief Range used for literals in reclass classes. + */ +struct reclass_range { + char32_t first{}; /// first character in span + char32_t last{}; /// last character in span (inclusive) +}; + /** * @brief Class type for regex compiler instruction. */ struct reclass { - int32_t builtins{0}; // bit mask identifying builtin classes - std::u32string literals; // ranges as pairs of utf-8 characters + int32_t builtins{0}; // bit mask identifying builtin classes + std::vector literals; reclass() {} reclass(int m) : builtins(m) {} + reclass(int m, std::vector&& l) : builtins(m), literals(std::move(l)) {} }; constexpr int32_t CCLASS_W{1 << 0}; // [a-z], [A-Z], [0-9], and '_' @@ -105,18 +114,19 @@ class reprog { static reprog create_from(std::string_view pattern, regex_flags const flags); int32_t add_inst(int32_t type); - int32_t add_inst(reinst inst); - int32_t add_class(reclass cls); + int32_t add_inst(reinst const& inst); + int32_t add_class(reclass const& cls); void set_groups_count(int32_t groups); [[nodiscard]] int32_t groups_count() const; - [[nodiscard]] const reinst* insts_data() const; [[nodiscard]] int32_t insts_count() const; - reinst& inst_at(int32_t id); + [[nodiscard]] reinst& inst_at(int32_t id); + [[nodiscard]] reinst const* insts_data() const; - reclass& class_at(int32_t id); [[nodiscard]] int32_t classes_count() const; + [[nodiscard]] reclass& class_at(int32_t id); + [[nodiscard]] reclass const* classes_data() const; [[nodiscard]] const int32_t* starts_data() const; [[nodiscard]] int32_t starts_count() const; diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 2ee195a2c5e..e899c84a48d 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -51,7 +51,7 @@ constexpr int32_t MINIMUM_THREADS = 256; // Minimum threads for computing w struct alignas(16) reclass_device { int32_t builtins{}; int32_t count{}; - char32_t const* literals{}; + reclass_range const* literals{}; __device__ inline bool is_match(char32_t const ch, uint8_t const* flags) const; }; diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index 8e2194f2094..fb2db86ab0b 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -141,7 +141,8 @@ __device__ __forceinline__ bool reclass_device::is_match(char32_t const ch, uint8_t const* codepoint_flags) const { for (int i = 0; i < count; ++i) { - if ((ch >= literals[i * 2]) && (ch <= literals[(i * 2) + 1])) { return true; } + auto const literal = literals[i]; + if ((ch >= literal.first) && (ch <= literal.last)) { return true; } } if (!builtins) return false; @@ -204,11 +205,11 @@ __device__ __forceinline__ void reprog_device::store(void* buffer) const auto classes = reinterpret_cast(ptr); result->_classes = classes; // fill in each class - auto d_ptr = reinterpret_cast(classes + _classes_count); + auto d_ptr = reinterpret_cast(classes + _classes_count); for (int idx = 0; idx < _classes_count; ++idx) { classes[idx] = _classes[idx]; classes[idx].literals = d_ptr; - for (int jdx = 0; jdx < _classes[idx].count * 2; ++jdx) + for (int jdx = 0; jdx < _classes[idx].count; ++jdx) *d_ptr++ = _classes[idx].literals[jdx]; } } diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 16f5b6fa03d..2ee1901c32d 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -25,6 +25,8 @@ #include #include +#include +#include namespace cudf { namespace strings { @@ -63,9 +65,12 @@ std::unique_ptr> reprog_devic // compute size of each section auto insts_size = insts_count * sizeof(_insts[0]); auto startids_size = starts_count * sizeof(_startinst_ids[0]); - auto classes_size = classes_count * sizeof(_classes[0]); - for (auto idx = 0; idx < classes_count; ++idx) - classes_size += static_cast((h_prog.class_at(idx).literals.size()) * sizeof(char32_t)); + auto classes_size = std::transform_reduce( + h_prog.classes_data(), + h_prog.classes_data() + h_prog.classes_count(), + classes_count * sizeof(_classes[0]), + std::plus{}, + [&h_prog](auto& cls) { return cls.literals.size() * sizeof(reclass_range); }); // make sure each section is aligned for the subsequent section's data type auto const memsize = cudf::util::round_up_safe(insts_size, sizeof(_startinst_ids[0])) + cudf::util::round_up_safe(startids_size, sizeof(_classes[0])) + @@ -104,14 +109,14 @@ std::unique_ptr> reprog_devic auto d_end = d_ptr + (classes_count * sizeof(reclass_device)); // place each class and append the variable length data for (int32_t idx = 0; idx < classes_count; ++idx) { - reclass& h_class = h_prog.class_at(idx); + auto const& h_class = h_prog.class_at(idx); reclass_device d_class{h_class.builtins, - static_cast(h_class.literals.size() / 2), - reinterpret_cast(d_end)}; + static_cast(h_class.literals.size()), + reinterpret_cast(d_end)}; *classes++ = d_class; - memcpy(h_end, h_class.literals.c_str(), h_class.literals.size() * sizeof(char32_t)); - h_end += h_class.literals.size() * sizeof(char32_t); - d_end += h_class.literals.size() * sizeof(char32_t); + memcpy(h_end, h_class.literals.data(), h_class.literals.size() * sizeof(reclass_range)); + h_end += h_class.literals.size() * sizeof(reclass_range); + d_end += h_class.literals.size() * sizeof(reclass_range); } // initialize the rest of the elements diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp index 9df22503c07..21c18977746 100644 --- a/cpp/tests/strings/contains_tests.cpp +++ b/cpp/tests/strings/contains_tests.cpp @@ -406,6 +406,23 @@ TEST_F(StringsContainsTests, FixedQuantifier) } } +TEST_F(StringsContainsTests, OverlappedClasses) +{ + auto input = cudf::test::strings_column_wrapper({"abcdefg", "defghí", "", "éééééé", "ghijkl"}); + auto sv = cudf::strings_column_view(input); + + { + auto results = cudf::strings::count_re(sv, "[e-gb-da-c]"); + cudf::test::fixed_width_column_wrapper expected({7, 4, 0, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } + { + auto results = cudf::strings::count_re(sv, "[á-éê-ú]"); + cudf::test::fixed_width_column_wrapper expected({0, 1, 0, 6, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + } +} + TEST_F(StringsContainsTests, MultiLine) { auto input =