From ac27757092e9ba2bc0656b6a7dfbc79ce8b5e76a Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 14 Apr 2022 08:10:26 -0400 Subject: [PATCH] Cleanup libcudf strings regex classes (#10573) Refactors some of the internal libcudf regex classes used for executing regex on strings. This is the first part of some changes to reduce kernel memory launch size for the regex code. A follow on PR will change the stack-based state management to a device memory approach. The changes here are isolated to help ease the review process in the next PR. Mostly code has been moved or refactored along with general cleanup like adding consts and removing some unnecessary pass-by-reference/pointer. None of the calling routines currently require changes and no behavior has changed. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Jake Hemstad (https://github.com/jrhemstad) URL: https://github.com/rapidsai/cudf/pull/10573 --- cpp/src/strings/regex/regcomp.cpp | 37 +++- cpp/src/strings/regex/regcomp.h | 2 +- cpp/src/strings/regex/regex.cuh | 109 ++++----- cpp/src/strings/regex/regex.inl | 352 ++++++++++++++---------------- cpp/src/strings/regex/regexec.cu | 130 +++++------ 5 files changed, 313 insertions(+), 317 deletions(-) diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 6f36658523b..829230d0842 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -16,6 +16,7 @@ #include +#include #include #include @@ -58,6 +59,37 @@ const std::array escapable_chars{ {'.', '-', '+', '*', '\\', '?', '^', '$', '|', '{', '}', '(', ')', '[', ']', '<', '>', '"', '~', '\'', '`', '_', '@', '=', ';', ':', '!', '#', '%', '&', ',', '/', ' '}}; +/** + * @brief Converts UTF-8 string into fixed-width 32-bit character vector. + * + * No character conversion occurs. + * Each UTF-8 character is promoted into a 32-bit value. + * The last entry in the returned vector will be a 0 value. + * The fixed-width vector makes it easier to compile and faster to execute. + * + * @param pattern Regular expression encoded with UTF-8. + * @return Fixed-width 32-bit character vector. + */ +std::vector string_to_char32_vector(std::string_view pattern) +{ + size_type size = static_cast(pattern.size()); + size_type count = std::count_if(pattern.cbegin(), pattern.cend(), [](char ch) { + return is_begin_utf8_char(static_cast(ch)); + }); + std::vector result(count + 1); + char32_t* output_ptr = result.data(); + const char* input_ptr = pattern.data(); + for (size_type idx = 0; idx < size; ++idx) { + char_utf8 output_character = 0; + size_type ch_width = to_char_utf8(input_ptr, output_character); + input_ptr += ch_width; + idx += ch_width - 1; + *output_ptr++ = output_character; + } + result[count] = 0; // last entry set to 0 + return result; +} + } // namespace int32_t reprog::add_inst(int32_t t) @@ -838,10 +870,11 @@ class regex_compiler { }; // Convert pattern into program -reprog reprog::create_from(const char32_t* pattern, regex_flags const flags) +reprog reprog::create_from(std::string_view pattern, regex_flags const flags) { reprog rtn; - regex_compiler compiler(pattern, flags, rtn); + auto pattern32 = string_to_char32_vector(pattern); + regex_compiler compiler(pattern32.data(), flags, rtn); // for debugging, it can be helpful to call rtn.print(flags) here to dump // out the instructions that have been created from the given pattern return rtn; diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h index 18735d0f980..798b43830b4 100644 --- a/cpp/src/strings/regex/regcomp.h +++ b/cpp/src/strings/regex/regcomp.h @@ -92,7 +92,7 @@ class reprog { * @brief Parses the given regex pattern and compiles * into a list of chained instructions. */ - static reprog create_from(const char32_t* pattern, regex_flags const flags); + static reprog create_from(std::string_view pattern, regex_flags const flags); int32_t add_inst(int32_t type); int32_t add_inst(reinst inst); diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index b172ceae2a6..bcdd15bceda 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -25,7 +25,6 @@ #include #include -#include #include namespace cudf { @@ -35,9 +34,7 @@ class string_view; namespace strings { namespace detail { -struct reljunk; -struct reinst; -class reprog; +struct relist; using match_pair = thrust::pair; using match_result = thrust::optional; @@ -65,19 +62,18 @@ constexpr int32_t RX_LARGE_INSTS = (RX_STACK_LARGE / 11); * * This class holds the unique data for any regex CCLASS instruction. */ -class reclass_device { - public: +struct alignas(16) reclass_device { int32_t builtins{}; int32_t count{}; - char32_t* literals{}; + char32_t const* literals{}; - __device__ bool is_match(char32_t ch, const uint8_t* flags); + __device__ inline bool is_match(char32_t const ch, uint8_t const* flags) const; }; /** * @brief Regex program of instructions/data for a specific regex pattern. * - * Once create, this find/extract methods are used to evaluating the regex instructions + * Once created, the find/extract methods are used to evaluate the regex instructions * against a single string. */ class reprog_device { @@ -132,15 +128,7 @@ class reprog_device { /** * @brief Returns the number of regex instructions. */ - [[nodiscard]] __host__ __device__ int32_t insts_counts() const { return _insts_count; } - - /** - * @brief Returns true if this is an empty program. - */ - [[nodiscard]] __device__ bool is_empty() const - { - return insts_counts() == 0 || get_inst(0)->type == END; - } + [[nodiscard]] CUDF_HOST_DEVICE int32_t insts_counts() const { return _insts_count; } /** * @brief Returns the number of regex groups found in the expression. @@ -151,19 +139,9 @@ class reprog_device { } /** - * @brief Returns the regex instruction object for a given index. - */ - [[nodiscard]] __device__ inline reinst* get_inst(int32_t idx) const; - - /** - * @brief Returns the regex class object for a given index. - */ - [[nodiscard]] __device__ inline reclass_device get_class(int32_t idx) const; - - /** - * @brief Returns the start-instruction-ids vector. + * @brief Returns true if this is an empty program. */ - [[nodiscard]] __device__ inline int32_t* startinst_ids() const; + [[nodiscard]] __device__ inline bool is_empty() const; /** * @brief Does a find evaluation using the compiled expression on the given string. @@ -180,9 +158,9 @@ class reprog_device { */ template __device__ inline int32_t find(int32_t idx, - string_view const& d_str, - int32_t& begin, - int32_t& end); + string_view const d_str, + cudf::size_type& begin, + cudf::size_type& end) const; /** * @brief Does an extract evaluation using the compiled expression on the given string. @@ -192,8 +170,8 @@ class reprog_device { * the matched section. * * @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`. - * @param idx The string index used for mapping the state memory for this string in global memory - * (if necessary). + * @param idx The string index used for mapping the state memory for this string in global + * memory (if necessary). * @param d_str The string to search. * @param begin Position index to begin the search. If found, returns the position found * in the string. @@ -204,34 +182,65 @@ class reprog_device { */ template __device__ inline match_result extract(cudf::size_type idx, - string_view const& d_str, + string_view const d_str, cudf::size_type begin, cudf::size_type end, - cudf::size_type group_id); + cudf::size_type const group_id) const; private: - int32_t _startinst_id, _num_capturing_groups; - int32_t _insts_count, _starts_count, _classes_count; - const uint8_t* _codepoint_flags{}; // table of character types - reinst* _insts{}; // array of regex instructions - int32_t* _startinst_ids{}; // array of start instruction ids - reclass_device* _classes{}; // array of regex classes - void* _relists_mem{}; // runtime relist memory for regexec + struct reljunk { + relist* __restrict__ list1; + relist* __restrict__ list2; + int32_t starttype{}; + char32_t startchar{}; + + __device__ inline reljunk(relist* list1, relist* list2, reinst const inst); + __device__ inline void swaplist(); + }; + + /** + * @brief Returns the regex instruction object for a given id. + */ + __device__ inline reinst get_inst(int32_t id) const; + + /** + * @brief Returns the regex class object for a given id. + */ + __device__ inline reclass_device get_class(int32_t id) const; /** * @brief Executes the regex pattern on the given string. */ - __device__ inline int32_t regexec( - string_view const& d_str, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id = 0); + __device__ inline int32_t regexec(string_view const d_str, + reljunk jnk, + cudf::size_type& begin, + cudf::size_type& end, + cudf::size_type const group_id = 0) const; /** * @brief Utility wrapper to setup state memory structures for calling regexec */ template - __device__ inline int32_t call_regexec( - int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t group_id = 0); - - reprog_device(reprog&); // must use create() + __device__ inline int32_t call_regexec(int32_t idx, + string_view const d_str, + cudf::size_type& begin, + cudf::size_type& end, + cudf::size_type const group_id = 0) const; + + reprog_device(reprog&); + + int32_t _startinst_id; // first instruction id + int32_t _num_capturing_groups; // instruction groups + int32_t _insts_count; // number of instructions + int32_t _starts_count; // number of start-insts ids + int32_t _classes_count; // number of classes + + uint8_t const* _codepoint_flags{}; // table of character types + reinst const* _insts{}; // array of regex instructions + int32_t const* _startinst_ids{}; // array of start instruction ids + reclass_device const* _classes{}; // array of regex classes + + void* _relists_mem{}; // runtime relist memory for regexec() }; } // namespace detail diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index 01e773960e4..9fe4440d7ec 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,16 +17,9 @@ #include #include -#include +#include #include -#include -#include -#include -#include -#include -#include - namespace cudf { namespace strings { namespace detail { @@ -40,95 +33,102 @@ namespace detail { * reflected here. The regexec function updates and manages this state data. */ struct alignas(8) relist { - int16_t size{}; - int16_t listsize{}; - int32_t reserved; - int2* ranges{}; // pair per instruction - int16_t* inst_ids{}; // one per instruction - u_char* mask{}; // bit per instruction - - CUDF_HOST_DEVICE inline static int32_t data_size_for(int32_t insts) + /** + * @brief Compute the memory size for the state data. + */ + constexpr inline static std::size_t data_size_for(int32_t insts) { - return ((sizeof(ranges[0]) + sizeof(inst_ids[0])) * insts) + ((insts + 7) / 8); + return ((sizeof(ranges[0]) + sizeof(inst_ids[0])) * insts) + + cudf::util::div_rounding_up_unsafe(insts, 8); } - CUDF_HOST_DEVICE inline static int32_t alloc_size(int32_t insts) + /** + * @brief Compute the aligned memory allocation size. + */ + constexpr inline static std::size_t alloc_size(int32_t insts) { - int32_t size = sizeof(relist); - size += data_size_for(insts); - size = ((size + 7) / 8) * 8; // align it too - return size; + return cudf::util::round_up_unsafe(data_size_for(insts) + sizeof(relist), + sizeof(ranges[0])); } - CUDF_HOST_DEVICE inline relist() {} + struct alignas(16) restate { + int2 range; + int32_t inst_id; + int32_t reserved; + }; - CUDF_HOST_DEVICE inline relist(int16_t insts, u_char* data = nullptr) : listsize(insts) + __device__ __forceinline__ relist(int16_t insts, u_char* data = nullptr) + : masksize(cudf::util::div_rounding_up_unsafe(insts, 8)) { auto ptr = data == nullptr ? reinterpret_cast(this) + sizeof(relist) : data; ranges = reinterpret_cast(ptr); - ptr += listsize * sizeof(ranges[0]); + ptr += insts * sizeof(ranges[0]); inst_ids = reinterpret_cast(ptr); - ptr += listsize * sizeof(inst_ids[0]); + ptr += insts * sizeof(inst_ids[0]); mask = ptr; reset(); } - CUDF_HOST_DEVICE inline void reset() + __device__ __forceinline__ void reset() { - memset(mask, 0, (listsize + 7) / 8); + memset(mask, 0, masksize); size = 0; } - __device__ inline bool activate(int32_t i, int32_t begin, int32_t end) + __device__ __forceinline__ bool activate(int32_t id, int32_t begin, int32_t end) { - if (readMask(i)) return false; - writeMask(true, i); - inst_ids[size] = static_cast(i); + if (readMask(id)) { return false; } + writeMask(id); + inst_ids[size] = static_cast(id); ranges[size] = int2{begin, end}; ++size; return true; } - __device__ inline void writeMask(bool v, int32_t pos) + __device__ __forceinline__ restate get_state(int16_t idx) const { - u_char uc = 1 << (pos & 7); - if (v) - mask[pos >> 3] |= uc; - else - mask[pos >> 3] &= ~uc; + return restate{ranges[idx], inst_ids[idx]}; } - __device__ inline bool readMask(int32_t pos) + __device__ __forceinline__ int16_t get_size() const { return size; } + + private: + int16_t size{}; + int16_t const masksize; + int32_t reserved; + int2* __restrict__ ranges; // pair per instruction + int16_t* __restrict__ inst_ids; // one per instruction + u_char* __restrict__ mask; // bit per instruction + + __device__ __forceinline__ void writeMask(int32_t pos) const { - u_char uc = mask[pos >> 3]; - return static_cast((uc >> (pos & 7)) & 1); + u_char const uc = 1 << (pos & 7); + mask[pos >> 3] |= uc; } -}; -/** - * @brief This manages the two relist instances required by the regexec function. - */ -struct reljunk { - relist* list1; - relist* list2; - int32_t starttype{}; - char32_t startchar{}; - - __host__ __device__ reljunk(relist* list1, relist* list2, int32_t stype, char32_t schar) - : list1(list1), list2(list2) + __device__ __forceinline__ bool readMask(int32_t pos) const { - if (starttype == CHAR || starttype == BOL) { - starttype = stype; - startchar = schar; - } + u_char const uc = mask[pos >> 3]; + return static_cast((uc >> (pos & 7)) & 1); } }; -__device__ inline void swaplist(relist*& l1, relist*& l2) +__device__ __forceinline__ reprog_device::reljunk::reljunk(relist* list1, + relist* list2, + reinst const inst) + : list1(list1), list2(list2) +{ + if (inst.type == CHAR || inst.type == BOL) { + starttype = inst.type; + startchar = inst.u1.c; + } +} + +__device__ __forceinline__ void reprog_device::reljunk::swaplist() { - relist* tmp = l1; - l1 = l2; - l2 = tmp; + auto tmp = list1; + list1 = list2; + list2 = tmp; } /** @@ -138,15 +138,13 @@ __device__ inline void swaplist(relist*& l1, relist*& l2) * @param codepoint_flags Used for mapping a character to type for builtin classes. * @return true if the character matches */ -__device__ inline bool reclass_device::is_match(char32_t ch, const uint8_t* codepoint_flags) +__device__ __forceinline__ bool reclass_device::is_match(char32_t const ch, + uint8_t const* codepoint_flags) const { - if (thrust::any_of(thrust::seq, - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(count), - [ch, this] __device__(int i) { - return ((ch >= literals[i * 2]) && (ch <= literals[(i * 2) + 1])); - })) - return true; + for (int i = 0; i < count; ++i) { + if ((ch >= literals[i * 2]) && (ch <= literals[(i * 2) + 1])) { return true; } + } + if (!builtins) return false; uint32_t codept = utf8_to_codepoint(ch); if (codept > 0x00FFFF) return false; @@ -167,20 +165,18 @@ __device__ inline bool reclass_device::is_match(char32_t ch, const uint8_t* code return false; } -__device__ inline reinst* reprog_device::get_inst(int32_t idx) const +__device__ __forceinline__ reinst reprog_device::get_inst(int32_t id) const { return _insts[id]; } + +__device__ __forceinline__ reclass_device reprog_device::get_class(int32_t id) const { - assert((idx >= 0) && (idx < _insts_count)); - return _insts + idx; + return _classes[id]; } -__device__ inline reclass_device reprog_device::get_class(int32_t idx) const +__device__ __forceinline__ bool reprog_device::is_empty() const { - assert((idx >= 0) && (idx < _classes_count)); - return _classes[idx]; + return insts_counts() == 0 || get_inst(0).type == END; } -__device__ inline int32_t* reprog_device::startinst_ids() const { return _startinst_ids; } - /** * @brief Evaluate a specific string against regex pattern compiled to this instance. * @@ -195,35 +191,36 @@ __device__ inline int32_t* reprog_device::startinst_ids() const { return _starti * @param group_id Index of the group to match in a multi-group regex pattern. * @return >0 if match found */ -__device__ inline int32_t reprog_device::regexec( - string_view const& dstr, reljunk& jnk, int32_t& begin, int32_t& end, int32_t group_id) +__device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr, + reljunk jnk, + cudf::size_type& begin, + cudf::size_type& end, + cudf::size_type const group_id) const { - int32_t match = 0; - auto checkstart = jnk.starttype; - auto pos = begin; - auto eos = end; - char32_t c = 0; - auto last_character = false; + int32_t match = 0; + auto pos = begin; + auto eos = end; + char_utf8 c = 0; + auto checkstart = jnk.starttype != 0; + auto last_character = false; + string_view::const_iterator itr = string_view::const_iterator(dstr, pos); jnk.list1->reset(); do { - /* fast check for first char */ + // fast check for first CHAR or BOL if (checkstart) { + auto startchar = static_cast(jnk.startchar); switch (jnk.starttype) { - case CHAR: { - auto fidx = dstr.find(static_cast(jnk.startchar), pos); - if (fidx < 0) return match; - pos = fidx; - break; - } - case BOL: { + case BOL: if (pos == 0) break; - if (jnk.startchar != '^') return match; + if (jnk.startchar != '^') { return match; } --pos; - int fidx = dstr.find(static_cast('\n'), pos); - if (fidx < 0) return match; // update begin/end values? - pos = fidx + 1; + startchar = static_cast('\n'); + case CHAR: { + auto const fidx = dstr.find(startchar, pos); + if (fidx < 0) { return match; } + pos = fidx + (jnk.starttype == BOL); break; } } @@ -231,128 +228,114 @@ __device__ inline int32_t reprog_device::regexec( } if (((eos < 0) || (pos < eos)) && match == 0) { - int32_t i = 0; - auto ids = startinst_ids(); - while (ids[i] >= 0) - jnk.list1->activate(ids[i++], (group_id == 0 ? pos : -1), -1); + auto ids = _startinst_ids; + while (*ids >= 0) + jnk.list1->activate(*ids++, (group_id == 0 ? pos : -1), -1); } - last_character = (pos >= dstr.length()); + last_character = itr.byte_offset() >= dstr.size_bytes(); - c = static_cast(last_character ? 0 : *itr); + c = last_character ? 0 : *itr; - // expand LBRA, RBRA, BOL, EOL, BOW, NBOW, and OR + // expand the non-character types like: LBRA, RBRA, BOL, EOL, BOW, NBOW, and OR bool expanded = false; do { jnk.list2->reset(); expanded = false; - for (int16_t i = 0; i < jnk.list1->size; i++) { - auto inst_id = static_cast(jnk.list1->inst_ids[i]); - int2& range = jnk.list1->ranges[i]; - const reinst* inst = get_inst(inst_id); + for (int16_t i = 0; i < jnk.list1->get_size(); i++) { + auto state = jnk.list1->get_state(i); + auto range = state.range; + auto const inst = get_inst(state.inst_id); int32_t id_activate = -1; - switch (inst->type) { + switch (inst.type) { case CHAR: case ANY: case ANYNL: case CCLASS: case NCCLASS: - case END: id_activate = inst_id; break; + case END: id_activate = state.inst_id; break; case LBRA: - if (inst->u1.subid == group_id) range.x = pos; - id_activate = inst->u2.next_id; + if (inst.u1.subid == group_id) range.x = pos; + id_activate = inst.u2.next_id; expanded = true; break; case RBRA: - if (inst->u1.subid == group_id) range.y = pos; - id_activate = inst->u2.next_id; + if (inst.u1.subid == group_id) range.y = pos; + id_activate = inst.u2.next_id; expanded = true; break; case BOL: - if ((pos == 0) || - ((inst->u1.c == '^') && (dstr[pos - 1] == static_cast('\n')))) { - id_activate = inst->u2.next_id; + if ((pos == 0) || ((inst.u1.c == '^') && (dstr[pos - 1] == '\n'))) { + id_activate = inst.u2.next_id; expanded = true; } break; case EOL: - if (last_character || (c == '\n' && inst->u1.c == '$')) { - id_activate = inst->u2.next_id; - expanded = true; - } - break; - case BOW: { - auto codept = utf8_to_codepoint(c); - auto last_c = static_cast(pos ? dstr[pos - 1] : 0); - auto last_codept = utf8_to_codepoint(last_c); - bool cur_alphaNumeric = (codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[codept]); - bool last_alphaNumeric = - (last_codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[last_codept]); - if (cur_alphaNumeric != last_alphaNumeric) { - id_activate = inst->u2.next_id; + if (last_character || (c == '\n' && inst.u1.c == '$')) { + id_activate = inst.u2.next_id; expanded = true; } break; - } + case BOW: case NBOW: { - auto codept = utf8_to_codepoint(c); - auto last_c = static_cast(pos ? dstr[pos - 1] : 0); - auto last_codept = utf8_to_codepoint(last_c); - bool cur_alphaNumeric = (codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[codept]); - bool last_alphaNumeric = + auto const codept = utf8_to_codepoint(c); + auto const last_c = pos > 0 ? dstr[pos - 1] : 0; + auto const last_codept = utf8_to_codepoint(last_c); + + bool const cur_alphaNumeric = + (codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[codept]); + bool const last_alphaNumeric = (last_codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[last_codept]); - if (cur_alphaNumeric == last_alphaNumeric) { - id_activate = inst->u2.next_id; + if ((cur_alphaNumeric == last_alphaNumeric) != (inst.type == BOW)) { + id_activate = inst.u2.next_id; expanded = true; } break; } case OR: - jnk.list2->activate(inst->u1.right_id, range.x, range.y); - id_activate = inst->u2.left_id; + jnk.list2->activate(inst.u1.right_id, range.x, range.y); + id_activate = inst.u2.left_id; expanded = true; break; } if (id_activate >= 0) jnk.list2->activate(id_activate, range.x, range.y); } - swaplist(jnk.list1, jnk.list2); + jnk.swaplist(); } while (expanded); - // execute + // execute instructions bool continue_execute = true; jnk.list2->reset(); - for (int16_t i = 0; continue_execute && i < jnk.list1->size; i++) { - auto inst_id = static_cast(jnk.list1->inst_ids[i]); - int2& range = jnk.list1->ranges[i]; - const reinst* inst = get_inst(inst_id); + for (int16_t i = 0; continue_execute && i < jnk.list1->get_size(); i++) { + auto const state = jnk.list1->get_state(i); + auto const range = state.range; + auto const inst = get_inst(state.inst_id); int32_t id_activate = -1; - switch (inst->type) { + switch (inst.type) { case CHAR: - if (inst->u1.c == c) id_activate = inst->u2.next_id; + if (inst.u1.c == c) id_activate = inst.u2.next_id; break; case ANY: - if (c != '\n') id_activate = inst->u2.next_id; + if (c != '\n') id_activate = inst.u2.next_id; break; - case ANYNL: id_activate = inst->u2.next_id; break; + case ANYNL: id_activate = inst.u2.next_id; break; + case NCCLASS: case CCLASS: { - reclass_device cls = get_class(inst->u1.cls_id); - if (cls.is_match(c, _codepoint_flags)) id_activate = inst->u2.next_id; - break; - } - case NCCLASS: { - reclass_device cls = get_class(inst->u1.cls_id); - if (!cls.is_match(c, _codepoint_flags)) id_activate = inst->u2.next_id; + auto const cls = get_class(inst.u1.cls_id); + if (cls.is_match(static_cast(c), _codepoint_flags) == (inst.type == CCLASS)) { + id_activate = inst.u2.next_id; + } break; } case END: match = 1; begin = range.x; end = group_id == 0 ? pos : range.y; - + // done with execute continue_execute = false; break; } @@ -362,18 +345,18 @@ __device__ inline int32_t reprog_device::regexec( ++pos; ++itr; - swaplist(jnk.list1, jnk.list2); - checkstart = jnk.list1->size > 0 ? 0 : 1; - } while (!last_character && (jnk.list1->size > 0 || match == 0)); + jnk.swaplist(); + checkstart = jnk.list1->get_size() == 0; + } while (!last_character && (!checkstart || !match)); return match; } template -__device__ inline int32_t reprog_device::find(int32_t idx, - string_view const& dstr, - int32_t& begin, - int32_t& end) +__device__ __forceinline__ int32_t reprog_device::find(int32_t idx, + string_view const dstr, + cudf::size_type& begin, + cudf::size_type& end) const { int32_t rtn = call_regexec(idx, dstr, begin, end); if (rtn <= 0) begin = end = -1; @@ -381,11 +364,11 @@ __device__ inline int32_t reprog_device::find(int32_t idx, } template -__device__ inline match_result reprog_device::extract(cudf::size_type idx, - string_view const& dstr, - cudf::size_type begin, - cudf::size_type end, - cudf::size_type group_id) +__device__ __forceinline__ match_result reprog_device::extract(cudf::size_type idx, + string_view const dstr, + cudf::size_type begin, + cudf::size_type end, + cudf::size_type const group_id) const { end = begin + 1; return call_regexec(idx, dstr, begin, end, group_id + 1) > 0 @@ -394,28 +377,29 @@ __device__ inline match_result reprog_device::extract(cudf::size_type idx, } template -__device__ inline int32_t reprog_device::call_regexec( - int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id) +__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t idx, + string_view const dstr, + cudf::size_type& begin, + cudf::size_type& end, + cudf::size_type const group_id) const { u_char data1[stack_size], data2[stack_size]; - auto const stype = get_inst(_startinst_id)->type; - auto const schar = get_inst(_startinst_id)->u1.c; - relist list1(static_cast(_insts_count), data1); relist list2(static_cast(_insts_count), data2); - reljunk jnk(&list1, &list2, stype, schar); + reljunk jnk(&list1, &list2, get_inst(_startinst_id)); return regexec(dstr, jnk, begin, end, group_id); } template <> -__device__ inline int32_t reprog_device::call_regexec( - int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id) +__device__ __forceinline__ int32_t +reprog_device::call_regexec(int32_t idx, + string_view const dstr, + cudf::size_type& begin, + cudf::size_type& end, + cudf::size_type const group_id) const { - auto const stype = get_inst(_startinst_id)->type; - auto const schar = get_inst(_startinst_id)->u1.c; - auto const relists_size = relist::alloc_size(_insts_count); auto* listmem = reinterpret_cast(_relists_mem); // beginning of relist buffer; listmem += (idx * relists_size * 2); // two relist ptrs in reljunk: @@ -423,7 +407,7 @@ __device__ inline int32_t reprog_device::call_regexec( auto* list1 = new (listmem) relist(static_cast(_insts_count)); auto* list2 = new (listmem + relists_size) relist(static_cast(_insts_count)); - reljunk jnk(list1, list2, stype, schar); + reljunk jnk(list1, list2, get_inst(_startinst_id)); return regexec(dstr, jnk, begin, end, group_id); } diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 3bcf55cf069..70d6079972a 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -27,39 +28,6 @@ namespace cudf { namespace strings { namespace detail { -namespace { -/** - * @brief Converts UTF-8 string into fixed-width 32-bit character vector. - * - * No character conversion occurs. - * Each UTF-8 character is promoted into a 32-bit value. - * The last entry in the returned vector will be a 0 value. - * The fixed-width vector makes it easier to compile and faster to execute. - * - * @param pattern Regular expression encoded with UTF-8. - * @return Fixed-width 32-bit character vector. - */ -std::vector string_to_char32_vector(std::string const& pattern) -{ - size_type size = static_cast(pattern.size()); - size_type count = std::count_if(pattern.cbegin(), pattern.cend(), [](char ch) { - return is_begin_utf8_char(static_cast(ch)); - }); - std::vector result(count + 1); - char32_t* output_ptr = result.data(); - const char* input_ptr = pattern.data(); - for (size_type idx = 0; idx < size; ++idx) { - char_utf8 output_character = 0; - size_type ch_width = to_char_utf8(input_ptr, output_character); - input_ptr += ch_width; - idx += ch_width - 1; - *output_ptr++ = output_character; - } - result[count] = 0; // last entry set to 0 - return result; -} - -} // namespace // Copy reprog primitive values reprog_device::reprog_device(reprog& prog) @@ -89,75 +57,76 @@ std::unique_ptr> reprog_devic size_type strings_count, rmm::cuda_stream_view stream) { - std::vector pattern32 = string_to_char32_vector(pattern); // compile pattern into host object - reprog h_prog = reprog::create_from(pattern32.data(), flags); + reprog h_prog = reprog::create_from(pattern, flags); + // compute size to hold all the member data - auto insts_count = h_prog.insts_count(); - auto classes_count = h_prog.classes_count(); - auto starts_count = h_prog.starts_count(); - // compute size of each section; make sure each is aligned appropriately - auto insts_size = - cudf::util::round_up_safe(insts_count * sizeof(_insts[0]), sizeof(size_t)); - auto startids_size = - cudf::util::round_up_safe(starts_count * sizeof(_startinst_ids[0]), sizeof(size_t)); - auto classes_size = - cudf::util::round_up_safe(classes_count * sizeof(_classes[0]), sizeof(size_t)); - for (int32_t idx = 0; idx < classes_count; ++idx) + auto const insts_count = h_prog.insts_count(); + auto const classes_count = h_prog.classes_count(); + auto const starts_count = h_prog.starts_count(); + + // compute size of each section + auto insts_size = insts_count * sizeof(_insts[0]); + auto startids_size = starts_count * sizeof(_startinst_ids[0]); + auto classes_size = classes_count * sizeof(_classes[0]); + for (auto idx = 0; idx < classes_count; ++idx) classes_size += static_cast((h_prog.class_at(idx).literals.size()) * sizeof(char32_t)); - size_t memsize = insts_size + startids_size + classes_size; - size_t rlm_size = 0; - // check memory size needed for executing regex - if (insts_count > RX_LARGE_INSTS) { - auto relist_alloc_size = relist::alloc_size(insts_count); - rlm_size = relist_alloc_size * 2L * strings_count; // reljunk has 2 relist ptrs - } + // make sure each section is aligned for the subsequent section's data type + auto const memsize = cudf::util::round_up_safe(insts_size, sizeof(_startinst_ids[0])) + + cudf::util::round_up_safe(startids_size, sizeof(_classes[0])) + + cudf::util::round_up_safe(classes_size, sizeof(char32_t)); + + // allocate memory to store all the prog data in a flat contiguous buffer + std::vector h_buffer(memsize); // copy everything into here; + auto h_ptr = h_buffer.data(); // this is our running host ptr; + auto d_buffer = new rmm::device_buffer(memsize, stream); // output device memory; + auto d_ptr = reinterpret_cast(d_buffer->data()); // running device pointer - // allocate memory to store prog data - std::vector h_buffer(memsize); - u_char* h_ptr = h_buffer.data(); // running pointer - auto* d_buffer = new rmm::device_buffer(memsize, stream); - u_char* d_ptr = reinterpret_cast(d_buffer->data()); // running device pointer // put everything into a flat host buffer first reprog_device* d_prog = new reprog_device(h_prog); - // copy the instructions array first (fixed-size structs) - reinst* insts = reinterpret_cast(h_ptr); - memcpy(insts, h_prog.insts_data(), insts_size); - h_ptr += insts_size; // next section + + // copy the instructions array first (fixed-sized structs) + memcpy(h_ptr, h_prog.insts_data(), insts_size); d_prog->_insts = reinterpret_cast(d_ptr); + + // point to the end for the next section + insts_size = cudf::util::round_up_safe(insts_size, sizeof(_startinst_ids[0])); + h_ptr += insts_size; d_ptr += insts_size; - // copy the startinst_ids next (ints) - int32_t* startinst_ids = reinterpret_cast(h_ptr); - memcpy(startinst_ids, h_prog.starts_data(), startids_size); - h_ptr += startids_size; // next section + // copy the startinst_ids next + memcpy(h_ptr, h_prog.starts_data(), startids_size); d_prog->_startinst_ids = reinterpret_cast(d_ptr); + + // next section; align the size for next data type + startids_size = cudf::util::round_up_safe(startids_size, sizeof(_classes[0])); + h_ptr += startids_size; d_ptr += startids_size; // copy classes into flat memory: [class1,class2,...][char32 arrays] - reclass_device* classes = reinterpret_cast(h_ptr); - d_prog->_classes = reinterpret_cast(d_ptr); + auto classes = reinterpret_cast(h_ptr); + d_prog->_classes = reinterpret_cast(d_ptr); // get pointer to the end to handle variable length data - u_char* h_end = h_ptr + (classes_count * sizeof(reclass_device)); - u_char* d_end = d_ptr + (classes_count * sizeof(reclass_device)); + auto h_end = h_ptr + (classes_count * sizeof(reclass_device)); + auto d_end = d_ptr + (classes_count * sizeof(reclass_device)); // place each class and append the variable length data for (int32_t idx = 0; idx < classes_count; ++idx) { reclass& h_class = h_prog.class_at(idx); - reclass_device d_class; - d_class.builtins = h_class.builtins; - d_class.count = h_class.literals.size() / 2; - d_class.literals = reinterpret_cast(d_end); - memcpy(classes++, &d_class, sizeof(d_class)); + reclass_device d_class{h_class.builtins, + static_cast(h_class.literals.size() / 2), + reinterpret_cast(d_end)}; + *classes++ = d_class; memcpy(h_end, h_class.literals.c_str(), h_class.literals.size() * sizeof(char32_t)); h_end += h_class.literals.size() * sizeof(char32_t); d_end += h_class.literals.size() * sizeof(char32_t); } + // initialize the rest of the elements - d_prog->_insts_count = insts_count; - d_prog->_starts_count = starts_count; - d_prog->_classes_count = classes_count; d_prog->_codepoint_flags = codepoint_flags; + // allocate execute memory if needed rmm::device_buffer* d_relists{}; - if (rlm_size > 0) { + if (insts_count > RX_LARGE_INSTS) { + // two relist state structures are needed for execute per string + auto const rlm_size = relist::alloc_size(insts_count) * 2 * strings_count; d_relists = new rmm::device_buffer(rlm_size, stream); d_prog->_relists_mem = d_relists->data(); } @@ -165,7 +134,8 @@ std::unique_ptr> reprog_devic // copy flat prog to device memory CUDF_CUDA_TRY(cudaMemcpyAsync( d_buffer->data(), h_buffer.data(), memsize, cudaMemcpyHostToDevice, stream.value())); - // + + // build deleter to cleanup device memory auto deleter = [d_buffer, d_relists](reprog_device* t) { t->destroy(); delete d_buffer;