diff --git a/CHANGELOG.md b/CHANGELOG.md index e1053a84b19..0f67ae3e01c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,7 @@ - PR #3314 Drop `cython` from run requirements - PR #3301 Add tests for empty column wrapper. - PR #3294 Update to arrow-cpp and pyarrow 0.15.1 +- PR #3292 Port NVStrings regex contains function - PR #3310 Add `row_hasher` and `element_hasher` utilities - PR #3272 Support non-default streams when creating/destroying hash maps - PR #3286 Clean up the starter code on README diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 91a6b293439..b4a7210d5a2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -549,6 +549,7 @@ add_library(cudf src/strings/case.cu src/strings/char_types/char_types.cu src/strings/combine.cu + src/strings/contains.cu src/strings/convert/convert_booleans.cu src/strings/convert/convert_datetime.cu src/strings/convert/convert_floats.cu @@ -562,6 +563,8 @@ add_library(cudf src/strings/find.cu src/strings/find_multiple.cu src/strings/padding.cu + src/strings/regex/regcomp.cpp + src/strings/regex/regexec.cu src/strings/replace/replace.cu src/strings/sorting/sorting.cu src/strings/split/split.cu diff --git a/cpp/include/cudf/strings/contains.hpp b/cpp/include/cudf/strings/contains.hpp new file mode 100644 index 00000000000..d4510426008 --- /dev/null +++ b/cpp/include/cudf/strings/contains.hpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace cudf +{ +namespace strings +{ + +/** + * @brief Returns a boolean column identifying rows which + * match the given regex pattern. + * + * ``` + * s = ["abc","123","def456"] + * r = contains_re(s,"\\d+") + * r is now [false, true, true] + * ``` + * + * Any null string entries return corresponding null output column entries. + * + * @param strings Strings instance for this operation. + * @param pattern Regex pattern to match to each string. + * @param mr Resource for allocating device memory. + * @return New column of boolean results for each string. + */ +std::unique_ptr contains_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource()); + +/** + * @brief Returns a boolean column identifying rows which + * match the given regex pattern but only at the beginning of the string. + * + * ``` + * s = ["abc","123","def456"] + * r = matches_re(s,"\\d+") + * r is now [false, true, false] + * ``` + * + * Any null string entries return corresponding null output column entries. + * + * @param strings Strings instance for this operation. + * @param pattern Regex pattern to match to each string. + * @param mr Resource for allocating device memory. + * @return New column of boolean results for each string. + */ +std::unique_ptr matches_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource()); + +/** + * @brief Returns the number of times the given regex pattern + * matches in each string. + * + * ``` + * s = ["abc","123","def45"] + * r = count_re(s,"\\d") + * r is now [0, 3, 2] + * ``` + * + * Any null string entries return corresponding null output column entries. + * + * @param strings Strings instance for this operation. 
+ * @param pattern Regex pattern to match within each string. + * @param mr Resource for allocating device memory. + * @return New INT32 column with counts for each string. + */ +std::unique_ptr count_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource()); + +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/char_types/is_flags.h b/cpp/src/strings/char_types/is_flags.h index 0bdc6bd02a7..56f60742ca7 100644 --- a/cpp/src/strings/char_types/is_flags.h +++ b/cpp/src/strings/char_types/is_flags.h @@ -17,13 +17,12 @@ // // 8-bit flag for each code-point. -// Flags for each character are defined in char_flags.h // -#define IS_DECIMAL(x) (x & 1) -#define IS_NUMERIC(x) (x & 2) -#define IS_DIGIT(x) (x & 4) -#define IS_ALPHA(x) (x & 8) -#define IS_ALPHANUM(x) (x & 15) -#define IS_SPACE(x) (x & 16) -#define IS_UPPER(x) (x & 32) -#define IS_LOWER(x) (x & 64) +#define IS_DECIMAL(x) ((x) & (1 << 0)) +#define IS_NUMERIC(x) ((x) & (1 << 1)) +#define IS_DIGIT(x) ((x) & (1 << 2)) +#define IS_ALPHA(x) ((x) & (1 << 3)) +#define IS_SPACE(x) ((x) & (1 << 4)) +#define IS_UPPER(x) ((x) & (1 << 5)) +#define IS_LOWER(x) ((x) & (1 << 6)) +#define IS_ALPHANUM(x) ((x) & (0x0F)) diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu new file mode 100644 index 00000000000..59b254ed2ce --- /dev/null +++ b/cpp/src/strings/contains.cu @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace cudf +{ +namespace strings +{ +namespace detail +{ +namespace +{ + +/** + * @brief This functor handles both contains_re and matches_re to minimize the number + * of regex find() calls that get inlined, greatly reducing compile time. + * + * The stack is used to keep progress on evaluating the regex instructions on each string. + * So the size of the stack is in proportion to the number of instructions in the given regex pattern. + * + * There are three call types based on the number of regex instructions in the given pattern. + * Small to medium instruction lengths can use the stack effectively though smaller executes faster. + * Longer patterns require global memory. + * + */ +template +struct contains_fn +{ + reprog_device prog; + column_device_view d_strings; + bool bmatch{false}; // do not make this a template parameter to keep compile times down + + __device__ cudf::experimental::bool8 operator()(size_type idx) + { + if( d_strings.is_null(idx) ) + return 0; + u_char data1[stack_size], data2[stack_size]; + prog.set_stack_mem(data1,data2); + string_view d_str = d_strings.element(idx); + int32_t begin = 0; + int32_t end = bmatch ? 
1 : d_str.length(); // 1=match only the beginning of the string + return static_cast(prog.find(idx,d_str,begin,end)); + } +}; + +// +std::unique_ptr contains_util( strings_column_view const& strings, + std::string const& pattern, + bool beginning_only = false, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(), + cudaStream_t stream = 0) +{ + auto strings_count = strings.size(); + auto strings_column = column_device_view::create(strings.parent(),stream); + auto d_column = *strings_column; + + // compile regex into device object + auto prog = reprog_device::create(pattern,get_character_flags_table(),strings_count,stream); + auto d_prog = *prog; + + // create the output column + auto results = make_numeric_column( data_type{BOOL8}, strings_count, + copy_bitmask( strings.parent(), stream, mr), strings.null_count(), stream, mr); + auto d_results = results->mutable_view().data(); + + // fill the output column + auto execpol = rmm::exec_policy(stream); + int regex_insts = d_prog.insts_counts(); + if( (regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS) ) + thrust::transform(execpol->on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, contains_fn{d_prog, d_column, beginning_only} ); + else if( regex_insts <= RX_MEDIUM_INSTS ) + thrust::transform(execpol->on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, contains_fn{d_prog, d_column, beginning_only} ); + else + thrust::transform(execpol->on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, contains_fn{d_prog, d_column, beginning_only} ); + + results->set_null_count(strings.null_count()); + return results; +} + +} // namespace + +std::unique_ptr contains_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(), + cudaStream_t 
stream = 0) +{ + return contains_util(strings, pattern, false, mr, stream); +} + +std::unique_ptr matches_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(), + cudaStream_t stream = 0) +{ + return contains_util(strings, pattern, true, mr, stream); +} + +} // namespace detail + +// external APIs + +std::unique_ptr contains_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr) +{ + return detail::contains_re(strings, pattern, mr); +} + +std::unique_ptr matches_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr) +{ + return detail::matches_re(strings, pattern, mr); +} + +namespace detail +{ + +namespace +{ + +/** + * @brief This counts the number of times the regex pattern matches in each string. + * + */ +template +struct count_fn +{ + reprog_device prog; + column_device_view d_strings; + + __device__ int32_t operator()(unsigned int idx) + { + u_char data1[stack_size], data2[stack_size]; + prog.set_stack_mem(data1,data2); + if( d_strings.is_null(idx) ) + return 0; + string_view d_str = d_strings.element(idx); + int32_t find_count = 0; + size_type nchars = d_str.length(); + size_type begin = 0; + while( begin <= nchars ) + { + auto end = nchars; + if( prog.find(idx,d_str,begin,end) <=0 ) + break; + ++find_count; + begin = end > begin ? 
end : begin + 1; + } + return find_count; + } +}; + +} + +std::unique_ptr count_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(), + cudaStream_t stream = 0) +{ + auto strings_count = strings.size(); + auto strings_column = column_device_view::create(strings.parent(),stream); + auto d_column = *strings_column; + + // compile regex into device object + auto prog = reprog_device::create(pattern,get_character_flags_table(),strings_count,stream); + auto d_prog = *prog; + + // create the output column + auto results = make_numeric_column( data_type{INT32}, strings_count, + copy_bitmask( strings.parent(), stream, mr), strings.null_count(), stream, mr); + auto d_results = results->mutable_view().data(); + + // fill the output column + auto execpol = rmm::exec_policy(stream); + int regex_insts = d_prog.insts_counts(); + if( (regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS) ) + thrust::transform(execpol->on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, count_fn{d_prog, d_column} ); + else if( regex_insts <= RX_MEDIUM_INSTS ) + thrust::transform(execpol->on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, count_fn{d_prog, d_column} ); + else + thrust::transform(execpol->on(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, count_fn{d_prog, d_column} ); + + results->set_null_count(strings.null_count()); + return results; + +} + +} // namespace detail + +// external API + +std::unique_ptr count_re( strings_column_view const& strings, + std::string const& pattern, + rmm::mr::device_memory_resource* mr) +{ + return detail::count_re(strings, pattern, mr); +} + +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp new file 
mode 100644 index 00000000000..80d6a589ebe --- /dev/null +++ b/cpp/src/strings/regex/regcomp.cpp @@ -0,0 +1,1178 @@ +/* +* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include + +#include + + +namespace cudf +{ +namespace strings +{ +namespace detail +{ +namespace +{ + +// Bitmask of all operators +#define OPERATOR_MASK 0200 +enum OperatorType +{ + START = 0200, // Start, used for marker on stack + LBRA_NC = 0203, // non-capturing group + CAT = 0205, // Concatenation, implicit operator + STAR = 0206, // Closure, * + STAR_LAZY = 0207, // Lazy closure, *? + PLUS = 0210, // a+ == aa* + PLUS_LAZY = 0211, // Lazy plus, +? + QUEST = 0212, // a? == a|nothing, i.e. 
0 or 1 a's + QUEST_LAZY = 0213, + COUNTED = 0214, // counted repeat a{2} a{3,5} + COUNTED_LAZY = 0215, // lazy counted repeat a{2}? a{3,5}? + NOP = 0302, // No operation, internal use only +}; + +static reclass ccls_w(1); // [a-z], [A-Z], [0-9], and '_' +static reclass ccls_W(8); // not ccls_w plus '\n' +static reclass ccls_s(2); // all spaces or ctrl characters +static reclass ccls_S(16); // not ccls_s +static reclass ccls_d(4); // digits [0-9] +static reclass ccls_D(32); // not ccls_d plus '\n' + +} // namespace + +int32_t reprog::add_inst(int32_t t) +{ + reinst inst; + inst.type = t; + inst.u2.left_id = 0; + inst.u1.right_id = 0; + return add_inst(inst); +} + +int32_t reprog::add_inst(reinst inst) +{ + _insts.push_back(inst); + return static_cast(_insts.size() - 1); +} + +int32_t reprog::add_class(reclass cls) +{ + _classes.push_back(cls); + return static_cast(_classes.size()-1); +} + +reinst& reprog::inst_at(int32_t id) +{ + return _insts[id]; +} + +reclass& reprog::class_at(int32_t id) +{ + return _classes[id]; +} + +void reprog::set_start_inst(int32_t id) +{ + _startinst_id = id; +} + +int32_t reprog::get_start_inst() const +{ + return _startinst_id; +} + +int32_t reprog::insts_count() const +{ + return static_cast(_insts.size()); +} + +int32_t reprog::classes_count() const +{ + return static_cast(_classes.size()); +} + +void reprog::set_groups_count(int32_t groups) +{ + _num_capturing_groups = groups; +} + +int32_t reprog::groups_count() const +{ + return _num_capturing_groups; +} + +const reinst* reprog::insts_data() const +{ + return _insts.data(); +} + +const int32_t* reprog::starts_data() const +{ + return _startinst_ids.data(); +} + +int32_t reprog::starts_count() const +{ + return static_cast(_startinst_ids.size()); +} + +// Converts pattern into regex classes +class regex_parser +{ + reprog& m_prog; + + const char32_t* exprp; + bool lexdone; + + int id_ccls_w = -1; // alphanumeric + int id_ccls_W = -1; // not alphanumeric + int id_ccls_s = -1; // space + int id_ccls_d = -1; // digit + 
int id_ccls_D = -1; // not digit + + char32_t yy; /* last lex'd Char */ + int yyclass_id; /* last lex'd class */ + short yy_min_count; + short yy_max_count; + + bool nextc(char32_t& c) // return "quoted" == backslash-escape prefix + { + if(lexdone) + { + c = 0; + return true; + } + c = *exprp++; + if(c == '\\') + { + c = *exprp++; + return true; + } + if(c == 0) + lexdone = true; + return false; + } + + int bldcclass() + { + int type = CCLASS; + std::vector cls; + int builtins = 0; + + /* look ahead for negation */ + /* SPECIAL CASE!!! negated classes don't match \n */ + char32_t c = 0; + int quoted = nextc(c); + if(!quoted && c == '^') + { + type = NCCLASS; + quoted = nextc(c); + cls.push_back('\n'); + cls.push_back('\n'); + } + + /* parse class into a set of spans */ + int count_char = 0; + while(true) + { + count_char++; + if(c == 0) + { + // malformed '[]' + return 0; + } + if(quoted) + { + switch(c) + { + case 'n': + c = '\n'; + break; + case 'r': + c = '\r'; + break; + case 't': + c = '\t'; + break; + case 'a': + c = 0x07; + break; + case 'b': + c = 0x08; + break; + case 'f': + c = 0x0C; + break; + case 'w': + builtins |= ccls_w.builtins; + quoted = nextc(c); + continue; + case 's': + builtins |= ccls_s.builtins; + quoted = nextc(c); + continue; + case 'd': + builtins |= ccls_d.builtins; + quoted = nextc(c); + continue; + case 'W': + builtins |= ccls_W.builtins; + quoted = nextc(c); + continue; + case 'S': + builtins |= ccls_S.builtins; + quoted = nextc(c); + continue; + case 'D': + builtins |= ccls_D.builtins; + quoted = nextc(c); + continue; + } + } + if(!quoted && c == ']' && count_char>1) + break; + if(!quoted && c == '-') + { + if (cls.size() < 1) + { + // malformed '[]' + return 0; + } + quoted = nextc(c); + if ((!quoted && c == ']') || c == 0) + { + // malformed '[]' + return 0; + } + cls[cls.size() - 1] = c; + } + else + { + cls.push_back(c); + cls.push_back(c); + } + quoted = nextc(c); + } + + /* sort on span start */ + for (int p = 0; p < 
cls.size(); p += 2) + for (int np = p + 2; np < cls.size(); np+=2) + if (cls[np] < cls[p]) + { + c = cls[np]; + cls[np] = cls[p]; + cls[p] = c; + c = cls[np+1]; + cls[np+1] = cls[p+1]; + cls[p+1] = c; + + } + + /* merge spans */ + reclass yycls{builtins}; + if( cls.size()>=2 ) + { + int np = 0; + int p = 0; + yycls.literals += cls[p++]; + yycls.literals += cls[p++]; + for (; p < cls.size(); p += 2) + { + /* overlapping or adjacent ranges? */ + if (cls[p] <= yycls.literals[np + 1] + 1) + { + if (cls[p + 1] >= yycls.literals[np + 1]) + yycls.literals.replace(np + 1, 1, 1, cls[p + 1]); /* coalesce */ + } + else + { + np += 2; + yycls.literals += cls[p]; + yycls.literals += cls[p+1]; + } + } + } + yyclass_id = m_prog.add_class(yycls); + return type; + } + + int lex(int dot_type) + { + int quoted = nextc(yy); + if(quoted) + { + if (yy == 0) + return END; + // treating all quoted numbers as Octal, since we are not supporting backreferences + if (yy >= '0' && yy <= '7') + { + yy = yy - '0'; + char32_t c = *exprp; // peek only: the char that ends the octal run must stay unread + while( c >= '0' && c <= '7' ) + { + yy = (yy << 3) | (c - '0'); + c = *(++exprp); // consume the digit just folded in, then peek at the next char + } + return CHAR; + } + else + { + switch (yy) + { + case 't': + yy = '\t'; + break; + case 'n': + yy = '\n'; + break; + case 'r': + yy = '\r'; + break; + case 'a': + yy = 0x07; + break; + case 'f': + yy = 0x0C; + break; + case '0': // unreachable: '0'-'7' are taken by the octal branch above + yy = 0; + break; + case 'x': + { + char32_t a = *exprp ? *exprp++ : 0; // never step past the terminating 0 + char32_t b = *exprp ? *exprp++ : 0; + yy = 0; + if (a >= '0' && a <= '9') yy += (a - '0') << 4; + else if (a >= 'a' && a <= 'f') yy += (a - 'a' + 10) << 4; + else if (a >= 'A' && a <= 'F') yy += (a - 'A' + 10) << 4; + if (b >= '0' && b <= '9') yy += b - '0'; + else if (b >= 'a' && b <= 'f') yy += b - 'a' + 10; + else if (b >= 'A' && b <= 'F') yy += b - 'A' + 10; + break; + } + case 'w': + { + if (id_ccls_w < 0) + { + yyclass_id = m_prog.add_class(ccls_w); + id_ccls_w = yyclass_id; + } + else yyclass_id = id_ccls_w; + return CCLASS; + } + case 'W': + { + if (id_ccls_W < 0) + { + reclass cls = ccls_w; + 
cls.literals += '\n'; + cls.literals += '\n'; + yyclass_id = m_prog.add_class(cls); + id_ccls_W = yyclass_id; + } + else yyclass_id = id_ccls_W; + return NCCLASS; + } + case 's': + { + if (id_ccls_s < 0) + { + yyclass_id = m_prog.add_class(ccls_s); + id_ccls_s = yyclass_id; + } + else yyclass_id = id_ccls_s; + return CCLASS; + } + case 'S': + { + if (id_ccls_s < 0) + { + yyclass_id = m_prog.add_class(ccls_s); + id_ccls_s = yyclass_id; + } + else yyclass_id = id_ccls_s; + return NCCLASS; + } + case 'd': + { + if (id_ccls_d < 0) + { + yyclass_id = m_prog.add_class(ccls_d); + id_ccls_d = yyclass_id; + } + else yyclass_id = id_ccls_d; + return CCLASS; + } + case 'D': + { + if (id_ccls_D < 0) + { + reclass cls = ccls_d; + cls.literals += '\n'; + cls.literals += '\n'; + yyclass_id = m_prog.add_class(cls); + id_ccls_D = yyclass_id; + } + else yyclass_id = id_ccls_D; + return NCCLASS; + } + case 'b': + return BOW; + case 'B': + return NBOW; + case 'A': + return BOL; + case 'Z': + return EOL; + } + return CHAR; + } + } + + switch(yy) + { + case 0: + return END; + case '*': + if (*exprp == '?') + { + exprp++; + return STAR_LAZY; + } + return STAR; + case '?': + if (*exprp == '?') + { + exprp++; + return QUEST_LAZY; + } + return QUEST; + case '+': + if (*exprp == '?') + { + exprp++; + return PLUS_LAZY; + } + return PLUS; + case '{': // counted repetition + { + if (*exprp<'0' || *exprp>'9') break; + const char32_t* exprp_backup = exprp; // in case '}' is not found + char buff[8] = {0}; + for (int i = 0; i < 7 && *exprp != '}' && *exprp != ',' && *exprp != 0; i++, exprp++) + { + buff[i] = *exprp; + buff[i + 1] = 0; + } + if (*exprp != '}' && *exprp != ',') + { + exprp = exprp_backup; + break; + } + sscanf(buff, "%hd", &yy_min_count); + if (*exprp != ',') + yy_max_count = yy_min_count; + else + { + yy_max_count = -1; + exprp++; + buff[0] = 0; + for (int i = 0; i < 7 && *exprp != '}' && *exprp != 0; i++, exprp++) + { + buff[i] = *exprp; + buff[i + 1] = 0; + } + if (*exprp != '}') 
+ { + exprp = exprp_backup; + break; + } + if (buff[0] != 0) + sscanf(buff, "%hd", &yy_max_count); + } + exprp++; + if (*exprp == '?') + { + exprp++; + return COUNTED_LAZY; + } + return COUNTED; + } + case '|': + return OR; + case '.': + return dot_type; + case '(': + if (*exprp == '?' && *(exprp + 1) == ':') // non-capturing group + { + exprp += 2; + return LBRA_NC; + } + return LBRA; + case ')': + return RBRA; + case '^': + return BOL; + case '$': + return EOL; + case '[': + return bldcclass(); + } + return CHAR; + } +public: + struct Item + { + int t; + union + { + char32_t yy; + int yyclass_id; + struct + { + short n; + short m; + } yycount; + } d; + }; + std::vector m_items; + + bool m_has_counted; + + regex_parser(const char32_t* pattern, int dot_type, reprog& prog) + : m_prog(prog), exprp(pattern), lexdone(false), m_has_counted(false) + { + int token = 0; + while((token = lex(dot_type)) != END) + { + Item item; + item.t = token; + if (token == CCLASS || token == NCCLASS) + item.d.yyclass_id = yyclass_id; + else if (token == COUNTED || token == COUNTED_LAZY) + { + item.d.yycount.n = yy_min_count; + item.d.yycount.m = yy_max_count; + m_has_counted = true; + } + else + item.d.yy = yy; + m_items.push_back(item); + } + } +}; + +/** + * @brief The compiler converts class list into instructions. 
+ */ +class regex_compiler +{ + reprog& m_prog; + + struct Node + { + int id_first; + int id_last; + }; + + int cursubid; + int pushsubid; + std::vector andstack; + + struct Ator + { + int t; + int subid; + }; + + std::vector atorstack; + + bool lastwasand; + int nbra; + + inline void pushand(int f, int l) + { + andstack.push_back({ f, l }); + } + + inline Node popand(int op) + { + if( andstack.size() < 1 ) + { + //missing operand for op + int inst_id = m_prog.add_inst(NOP); + pushand(inst_id, inst_id); + } + Node node = andstack[andstack.size() - 1]; + andstack.pop_back(); + return node; + } + + inline void pushator(int t) + { + Ator ator; + ator.t = t; + ator.subid = pushsubid; + atorstack.push_back(ator); + } + + inline Ator popator() + { + Ator ator = atorstack[atorstack.size() - 1]; + atorstack.pop_back(); + return ator; + } + + void evaluntil(int pri) + { + Node op1; + Node op2; + int id_inst1 = -1; + int id_inst2 = -1; + while( pri == RBRA || atorstack[atorstack.size() - 1].t >= pri ) + { + Ator ator = popator(); + switch(ator.t) + { + default: + // unknown operator in evaluntil + break; + case LBRA: /* must have been RBRA */ + op1 = popand('('); + id_inst2 = m_prog.add_inst(RBRA); + m_prog.inst_at(id_inst2).u1.subid = ator.subid;//subidstack[subidstack.size()-1]; + m_prog.inst_at(op1.id_last).u2.next_id = id_inst2; + id_inst1 = m_prog.add_inst(LBRA); + m_prog.inst_at(id_inst1).u1.subid = ator.subid;//subidstack[subidstack.size() - 1]; + m_prog.inst_at(id_inst1).u2.next_id = op1.id_first; + pushand(id_inst1, id_inst2); + return; + case OR: + op2 = popand('|'); + op1 = popand('|'); + id_inst2 = m_prog.add_inst(NOP); + m_prog.inst_at(op2.id_last).u2.next_id = id_inst2; + m_prog.inst_at(op1.id_last).u2.next_id = id_inst2; + id_inst1 = m_prog.add_inst(OR); + m_prog.inst_at(id_inst1).u1.right_id = op1.id_first; + m_prog.inst_at(id_inst1).u2.left_id = op2.id_first; + pushand(id_inst1, id_inst2); + break; + case CAT: + op2 = popand(0); + op1 = popand(0); + 
m_prog.inst_at(op1.id_last).u2.next_id = op2.id_first; + pushand(op1.id_first, op2.id_last); + break; + case STAR: + op2 = popand('*'); + id_inst1 = m_prog.add_inst(OR); + m_prog.inst_at(op2.id_last).u2.next_id = id_inst1; + m_prog.inst_at(id_inst1).u1.right_id = op2.id_first; + pushand(id_inst1, id_inst1); + break; + case STAR_LAZY: + op2 = popand('*'); + id_inst1 = m_prog.add_inst(OR); + id_inst2 = m_prog.add_inst(NOP); + m_prog.inst_at(op2.id_last).u2.next_id = id_inst1; + m_prog.inst_at(id_inst1).u2.left_id = op2.id_first; + m_prog.inst_at(id_inst1).u1.right_id = id_inst2; + pushand(id_inst1, id_inst2); + break; + case PLUS: + op2 = popand('+'); + id_inst1 = m_prog.add_inst(OR); + m_prog.inst_at(op2.id_last).u2.next_id = id_inst1; + m_prog.inst_at(id_inst1).u1.right_id = op2.id_first; + pushand(op2.id_first, id_inst1); + break; + case PLUS_LAZY: + op2 = popand('+'); + id_inst1 = m_prog.add_inst(OR); + id_inst2 = m_prog.add_inst(NOP); + m_prog.inst_at(op2.id_last).u2.next_id = id_inst1; + m_prog.inst_at(id_inst1).u2.left_id = op2.id_first; + m_prog.inst_at(id_inst1).u1.right_id = id_inst2; + pushand(op2.id_first, id_inst2); + break; + case QUEST: + op2 = popand('?'); + id_inst1 = m_prog.add_inst(OR); + id_inst2 = m_prog.add_inst(NOP); + m_prog.inst_at(id_inst1).u2.left_id = id_inst2; + m_prog.inst_at(id_inst1).u1.right_id = op2.id_first; + m_prog.inst_at(op2.id_last).u2.next_id = id_inst2; + pushand(id_inst1, id_inst2); + break; + case QUEST_LAZY: + op2 = popand('?'); + id_inst1 = m_prog.add_inst(OR); + id_inst2 = m_prog.add_inst(NOP); + m_prog.inst_at(id_inst1).u2.left_id = op2.id_first; + m_prog.inst_at(id_inst1).u1.right_id = id_inst2; + m_prog.inst_at(op2.id_last).u2.next_id = id_inst2; + pushand(id_inst1, id_inst2); + break; + } + } + } + + void Operator(int t) + { + if (t == RBRA && --nbra < 0) + //unmatched right paren + return; + if (t == LBRA) + { + nbra++; + if (lastwasand) + Operator(CAT); + } + else + evaluntil(t); + if (t != RBRA) + pushator(t); + 
lastwasand = ( + t == STAR || t == QUEST || t == PLUS || + t == STAR_LAZY || t == QUEST_LAZY || t == PLUS_LAZY || + t == RBRA); + } + + void Operand(int t) + { + if (lastwasand) + Operator(CAT); /* catenate is implicit */ + int inst_id = m_prog.add_inst(t); + if (t == CCLASS || t == NCCLASS) + m_prog.inst_at(inst_id).u1.cls_id = yyclass_id; + else if (t == CHAR || t==BOL || t==EOL) + m_prog.inst_at(inst_id).u1.c = yy; + pushand(inst_id, inst_id); + lastwasand = true; + } + + char32_t yy; + int yyclass_id; + + void expand_counted(const std::vector& in, std::vector& out) + { + std::vector lbra_stack; + int rep_start = -1; + + out.clear(); + for (int i = 0; i < in.size(); i++) + { + if (in[i].t != COUNTED && in[i].t != COUNTED_LAZY) + { + out.push_back(in[i]); + if (in[i].t == LBRA || in[i].t == LBRA_NC) + { + lbra_stack.push_back(i); + rep_start = -1; + } + else if (in[i].t == RBRA) + { + rep_start = lbra_stack.empty() ? -1 : lbra_stack.back(); // unmatched ')': nothing to repeat, treated as broken regex below + if (!lbra_stack.empty()) lbra_stack.pop_back(); + } + else if ((in[i].t & 0300) != OPERATOR_MASK) + { + rep_start = i; + } + } + else + { + if (rep_start < 0) // broken regex + return; + + regex_parser::Item item = in[i]; + if (item.d.yycount.n <= 0) + { + // need to erase + for (int j = 0; j < i - rep_start; j++) + out.pop_back(); + } + else + { + // repeat + for (int j = 1; j < item.d.yycount.n; j++) + for (int k = rep_start; k < i; k++) + out.push_back(in[k]); + } + + // optional repeats + if (item.d.yycount.m >= 0) + { + for (int j = item.d.yycount.n; j < item.d.yycount.m; j++) + { + regex_parser::Item o_item; + o_item.t = LBRA_NC; + o_item.d.yy = 0; + out.push_back(o_item); + for (int k = rep_start; k < i; k++) + out.push_back(in[k]); + } + for (int j = item.d.yycount.n; j < item.d.yycount.m; j++) + { + regex_parser::Item o_item; + o_item.t = RBRA; + o_item.d.yy = 0; + out.push_back(o_item); + if (item.t == COUNTED) + { + o_item.t = QUEST; + out.push_back(o_item); + } + else + { + o_item.t = QUEST_LAZY; + out.push_back(o_item); + } + } + } + else // 
infinite repeat + { + regex_parser::Item o_item; + o_item.d.yy = 0; + + if (item.d.yycount.n > 0) // put '+' after last repetition + { + if (item.t == COUNTED) + { + o_item.t = PLUS; + out.push_back(o_item); + } + else + { + o_item.t = PLUS_LAZY; + out.push_back(o_item); + } + } + else // copy it once then put '*' + { + for (int k = rep_start; k < i; k++) + out.push_back(in[k]); + + if (item.t == COUNTED) + { + o_item.t = STAR; + out.push_back(o_item); + } + else + { + o_item.t = STAR_LAZY; + out.push_back(o_item); + } + } + } + } + } + } + + +public: + regex_compiler(const char32_t* pattern, int dot_type, reprog& prog) + : m_prog(prog), cursubid(0), pushsubid(0), lastwasand(false), nbra(0), + yyclass_id(0), yy(0) + { + // Parse + std::vector items; + { + regex_parser parser(pattern, dot_type, m_prog); + + // Expand counted repetitions + if (parser.m_has_counted) + expand_counted(parser.m_items, items); + else + items = parser.m_items; + } + + /* Start with a low priority operator to prime parser */ + pushator(START - 1); + + for (int i = 0; i < static_cast(items.size()); i++) + { + regex_parser::Item item = items[i]; + int token = item.t; + if (token == CCLASS || token == NCCLASS) + yyclass_id = item.d.yyclass_id; + else + yy = item.d.yy; + + if (token == LBRA) + { + ++cursubid; + pushsubid = cursubid; + } + else if (token == LBRA_NC) + { + pushsubid = 0; + token = LBRA; + } + + if ((token & 0300) == OPERATOR_MASK) + Operator(token); + else + Operand(token); + } + + /* Close with a low priority operator */ + evaluntil(START); + /* Force END */ + Operand(END); + evaluntil(START); + if (nbra) + ; // "unmatched left paren"; + /* points to first and only operand */ + m_prog.set_start_inst(andstack[andstack.size() - 1].id_first); + m_prog.optimize1(); + m_prog.optimize2(); + m_prog.set_groups_count(cursubid); + } +}; + +// Convert pattern into program +reprog reprog::create_from(const char32_t* pattern) +{ + reprog rtn; + regex_compiler compiler(pattern, ANY, rtn); // 
future feature: ANYNL + //rtn->print(); + return rtn; +} + +// +void reprog::optimize1() +{ + // Treat non-capturing LBRAs/RBRAs as NOOP + for (int i = 0; i < static_cast(_insts.size()); i++) + { + if (_insts[i].type == LBRA || _insts[i].type == RBRA) + { + if (_insts[i].u1.subid < 1) + { + _insts[i].type = NOP; + } + } + } + + // get rid of NOP chains + for (int i=0; i < insts_count(); i++) + { + if( _insts[i].type != NOP ) + { + { + int target_id = _insts[i].u2.next_id; + while(_insts[target_id].type == NOP) + target_id = _insts[target_id].u2.next_id; + _insts[i].u2.next_id = target_id; + } + if( _insts[i].type == OR ) + { + int target_id = _insts[i].u1.right_id; + while(_insts[target_id].type == NOP) + target_id = _insts[target_id].u2.next_id; + _insts[i].u1.right_id = target_id; + } + } + } + // skip NOPs from the beginning + { + int target_id = _startinst_id; + while( _insts[target_id].type == NOP) + target_id = _insts[target_id].u2.next_id; + _startinst_id = target_id; + } + // actually remove the no-ops + std::vector id_map(insts_count()); + int j = 0; // compact the ops (non no-ops) + for( int i = 0; i < insts_count(); i++) + { + id_map[i] = j; + if( _insts[i].type != NOP ) + { + _insts[j] = _insts[i]; + j++; + } + } + _insts.resize(j); + // fix up the ORs + for( int i=0; i < insts_count(); i++) + { + { + int target_id = _insts[i].u2.next_id; + _insts[i].u2.next_id = id_map[target_id]; + } + if( _insts[i].type == OR ) + { + int target_id = _insts[i].u1.right_id; + _insts[i].u1.right_id = id_map[target_id]; + } + } + // set the new start id + _startinst_id = id_map[_startinst_id]; +} + +// expand leading ORs to multiple startinst_ids +void reprog::optimize2() +{ + _startinst_ids.clear(); + std::vector stack; + stack.push_back(_startinst_id); + while(!stack.empty()) + { + int id = stack.back(); + stack.pop_back(); + const reinst& inst = _insts[id]; + if(inst.type == OR) + { + stack.push_back(inst.u2.left_id); + stack.push_back(inst.u1.right_id); + } + else + 
{ + _startinst_ids.push_back(id); + } + } + _startinst_ids.push_back(-1); // terminator mark +} + +void reprog::print() +{ + printf("Instructions:\n"); + for(int i = 0; i < _insts.size(); i++) + { + const reinst& inst = _insts[i]; + printf("%d :", i); + switch (inst.type) + { + default: + printf("Unknown instruction: %d, nextid= %d", inst.type, inst.u2.next_id); + break; + case CHAR: + if( inst.u1.c <=32 || inst.u1.c >=127 ) + printf("CHAR, c = '0x%02x', nextid= %d", static_cast(inst.u1.c), inst.u2.next_id); + else + printf("CHAR, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); + break; + case RBRA: + printf("RBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); + break; + case LBRA: + printf("LBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); + break; + case OR: + printf("OR, rightid=%d, leftid=%d, nextid=%d", inst.u1.right_id, inst.u2.left_id, inst.u2.next_id); + break; + case STAR: + printf("STAR, nextid= %d", inst.u2.next_id); + break; + case PLUS: + printf("PLUS, nextid= %d", inst.u2.next_id); + break; + case QUEST: + printf("QUEST, nextid= %d", inst.u2.next_id); + break; + case ANY: + printf("ANY, nextid= %d", inst.u2.next_id); + break; + case ANYNL: + printf("ANYNL, nextid= %d", inst.u2.next_id); + break; + case NOP: + printf("NOP, nextid= %d", inst.u2.next_id); + break; + case BOL: + printf("BOL, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); + break; + case EOL: + printf("EOL, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); + break; + case CCLASS: + printf("CCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); + break; + case NCCLASS: + printf("NCCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); + break; + case BOW: + printf("BOW, nextid= %d", inst.u2.next_id); + break; + case NBOW: + printf("NBOW, nextid= %d", inst.u2.next_id); + break; + case END: + printf("END"); + break; + } + printf("\n"); + } + + printf("startinst_id=%d\n", _startinst_id); + if( _startinst_ids.size() > 0 ) + { + 
printf("startinst_ids:"); + for (size_t i = 0; i < _startinst_ids.size(); i++) + printf(" %d", _startinst_ids[i]); + printf("\n"); + } + + int count = static_cast(_classes.size()); + printf("\nClasses %d\n",count); + for( int i = 0; i < count; i++ ) + { + const reclass& cls = _classes[i]; + int len = static_cast(cls.literals.size()); + printf("%2d: ", i); + for( int j=0; j < len; j += 2 ) + { + char32_t c1 = cls.literals[j]; + char32_t c2 = cls.literals[j+1]; + if( c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127 ) + printf("0x%02x-0x%02x",static_cast(c1),static_cast(c2)); + else + printf("%c-%c",static_cast(c1),static_cast(c2)); + if( (j+2) < len ) + printf(", "); + } + printf("\n"); + if( cls.builtins ) + { + int mask = cls.builtins; + printf(" builtins(x%02X):",static_cast(mask)); + if( mask & 1 ) + printf(" \\w"); + if( mask & 2 ) + printf(" \\s"); + if( mask & 4 ) + printf(" \\d"); + if( mask & 8 ) + printf(" \\W"); + if( mask & 16 ) + printf(" \\S"); + if( mask & 32 ) + printf(" \\D"); + } + printf("\n"); + } + if( _num_capturing_groups ) + printf("Number of capturing groups: %d\n", _num_capturing_groups); +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h new file mode 100644 index 00000000000..4de3a96977c --- /dev/null +++ b/cpp/src/strings/regex/regcomp.h @@ -0,0 +1,137 @@ +/* +* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
namespace cudf
{
namespace strings
{
namespace detail
{

/**
 * @brief Regex instruction types: operators (actions) and tokens (operands).
 *
 * The values are written in octal on purpose:
 * ```
 * 02xx are operators, value == precedence
 * 03xx are tokens, i.e. operands for operators
 * ```
 */
enum InstType
{
    CHAR    = 0177, ///< literal character
    RBRA    = 0201, ///< right bracket, `)`
    LBRA    = 0202, ///< left bracket, `(`
    OR      = 0204, ///< alternation, `|`
    ANY     = 0300, ///< any character except newline, `.`
    ANYNL   = 0301, ///< any character including newline
    BOL     = 0303, ///< beginning of line, `^`
    EOL     = 0304, ///< end of line, `$`
    CCLASS  = 0305, ///< character class, `[]`
    NCCLASS = 0306, ///< negated character class, `[^]`
    BOW     = 0307, ///< word boundary, `\b`
    NBOW    = 0310, ///< not a word boundary, `\B`
    END     = 0377  ///< terminate: match found
};

/**
 * @brief Host-side description of one regex character class.
 */
struct reclass
{
    int32_t builtins{0};     ///< bit mask identifying builtin classes
    std::u32string literals; ///< explicit ranges stored as (low, high) character pairs

    reclass() {}
    reclass(int m) : builtins(m) {}
};

/**
 * @brief One encoded regex instruction.
 */
struct reinst
{
    int32_t type; ///< operator type or instruction type

    /// operand; which member is meaningful depends on `type`
    union
    {
        int32_t cls_id;   ///< class index
        char32_t c;       ///< character
        int32_t subid;    ///< sub-expression id for RBRA and LBRA
        int32_t right_id; ///< right child of OR
    } u1;

    /// link to the following instruction; regexec relies on both members
    /// living in the same union
    union
    {
        int32_t left_id; ///< left child of OR
        int32_t next_id; ///< next instruction for CAT & LBRA
    } u2;

    int32_t reserved4;
};

/**
 * @brief Host-side regex program: the result of parsing a pattern into a set
 * of chained instructions plus the character classes they reference.
 */
class reprog
{
public:
    reprog()                         = default;
    reprog(const reprog&)            = default;
    reprog(reprog&&)                 = default;
    ~reprog()                        = default;
    reprog& operator=(const reprog&) = default;
    reprog& operator=(reprog&&)      = default;

    /**
     * @brief Parses the given regex pattern and compiles it into a list of
     * chained instructions.
     */
    static reprog create_from(const char32_t* pattern);

    // building the program
    int32_t add_inst(int32_t type);
    int32_t add_inst(reinst inst);
    int32_t add_class(reclass cls);

    // capture groups
    void set_groups_count(int32_t groups);
    int32_t groups_count() const;

    // instructions
    const reinst* insts_data() const;
    int32_t insts_count() const;
    reinst& inst_at(int32_t id);

    // character classes
    reclass& class_at(int32_t id);
    int32_t classes_count() const;

    // start instruction ids
    const int32_t* starts_data() const;
    int32_t starts_count() const;

    void set_start_inst(int32_t id);
    int32_t get_start_inst() const;

    // post-parse passes
    void optimize1();
    void optimize2();
    void print(); // for debugging

private:
    std::vector<reinst> _insts;
    std::vector<reclass> _classes;
    int32_t _startinst_id;
    std::vector<int32_t> _startinst_ids; // short-cut to speed-up ORs
    int32_t _num_capturing_groups;
};

} // namespace detail
} // namespace strings
} // namespace cudf
namespace cudf
{

class string_view;

namespace strings
{
namespace detail
{

struct reljunk;
struct reinst;
class reprog;

/**
 * @brief Regex class stored on the device and executed by reprog_device.
 *
 * This class holds the unique data for any regex CCLASS instruction.
 */
class reclass_device
{
public:
    int32_t builtins{};   // bit mask of builtin classes (tested bit-by-bit in is_match)
    int32_t count{};      // is_match() treats this as the number of (low,high) pairs in literals
    char32_t* literals{}; // explicit ranges, stored as consecutive (low,high) pairs

    /**
     * @brief Returns true if the character belongs to this class.
     *
     * @param ch Character to test.
     * @param flags Code-point lookup table used for the builtin classes.
     */
    __device__ bool is_match(char32_t ch, const uint8_t* flags);
};

/**
 * @brief Regex program of instructions/data for a specific regex pattern.
 *
 * Once created, the find/extract methods are used to evaluate the regex instructions
 * against a single string.
 */
class reprog_device
{
public:
    reprog_device() = delete;
    ~reprog_device() = default;
    reprog_device(const reprog_device&) = default;
    reprog_device(reprog_device&&) = default;
    reprog_device& operator=(const reprog_device&) = default;
    reprog_device& operator=(reprog_device&&) = default;

    /**
     * @brief Create device program instance from a regex pattern.
     *
     * The number of strings is needed to compute the state data size required when evaluating the regex.
     *
     * NOTE(review): the template arguments of the return type are missing in this copy of the
     * source; presumably a unique_ptr with a custom deleter that calls destroy() -- confirm.
     *
     * @param pattern The regex pattern to compile.
     * @param cp_flags The code-point lookup table for character types.
     * @param strings_count Number of strings that will be evaluated.
     * @param stream CUDA stream for asynchronous memory allocations.
     * @return The program device object.
     */
    static std::unique_ptr>
        create(std::string const& pattern, const uint8_t* cp_flags, int32_t strings_count, cudaStream_t stream=0);
    /**
     * @brief Called automatically by the unique_ptr returned from create().
     */
    void destroy();

    /**
     * @brief Returns the number of regex instructions.
     */
    int32_t insts_counts() const { return _insts_count; }

    /**
     * @brief Returns the number of regex groups found in the expression.
     */
    int32_t group_counts() const { return _num_capturing_groups; }

    /**
     * @brief This sets up the memory used for keeping track of the regex progress.
     *
     * Call this for each string before calling find or extract.
     *
     * @param s1 Buffer backing the first relist object.
     * @param s2 Buffer backing the second relist object.
     */
    __device__ inline void set_stack_mem(u_char* s1, u_char* s2);

    /**
     * @brief Returns the regex instruction object for a given index.
     */
    __host__ __device__ inline reinst* get_inst(int32_t idx) const;

    /**
     * @brief Returns the regex class object for a given index.
     */
    __device__ inline reclass_device get_class(int32_t idx) const;

    /**
     * @brief Returns the start-instruction-ids vector.
     */
    __device__ inline int32_t* startinst_ids() const;

    /**
     * @brief Does a find evaluation using the compiled expression on the given string.
     *
     * @param idx The string index used for mapping the state memory for this string in global memory (if necessary).
     * @param d_str The string to search.
     * @param[in,out] begin Position index to begin the search. If found, returns the position found in the string.
     *                      Set to -1 when no match is found.
     * @param[in,out] end Position index to end the search. If found, returns the last position matching in the string.
     *                    Set to -1 when no match is found.
     * @return Returns 0 if no match is found.
     */
    __device__ inline int32_t find( int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end );

    /**
     * @brief Does an extract evaluation using the compiled expression on the given string.
     *
     * This will find a specific match within the string when more than one match occurs.
     *
     * @param idx The string index used for mapping the state memory for this string in global memory (if necessary).
     * @param d_str The string to search.
     * @param[in,out] begin Position index to begin the search. If found, returns the position found in the string.
     * @param[in,out] end Position index to end the search. If found, returns the last position matching in the string.
     *                    Note the implementation overwrites this with begin+1 before searching.
     * @param column The specific instance to return if more than one match is found
     *               (forwarded to the matcher as group id `column + 1`).
     * @return Returns 0 if no match is found.
     */
    __device__ inline int32_t extract( int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t column );

private:
    int32_t _startinst_id, _num_capturing_groups;
    int32_t _insts_count, _starts_count, _classes_count;
    const uint8_t* _codepoint_flags{}; // table of character types
    reinst* _insts{};                  // array of regex instructions
    int32_t* _startinst_ids{};         // array of start instruction ids
    reclass_device* _classes{};        // array of regex classes
    void* _relists_mem{};              // runtime relist memory for regexec
    u_char* _stack_mem1{};             // memory for relist object 1
    u_char* _stack_mem2{};             // memory for relist object 2

    /**
     * @brief Executes the regex pattern on the given string.
     */
    __device__ inline int32_t regexec( string_view const& d_str, reljunk& jnk, int32_t& begin, int32_t& end, int32_t groupid=0 );

    /**
     * @brief Utility wrapper to set up the state memory structures for calling regexec.
     */
    __device__ inline int32_t call_regexec( int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t groupid=0 );

    reprog_device(reprog&); // must use create()
};


// Sizing of the per-string state ("stack") memory.
// 10128 ≈ 1000 instructions
// Formula is based on the relist::data_size_for() calculation;
// Stack ≈ (8+2)*x + (x/8) = 10.125x < 11x where x is number of instructions
constexpr int32_t MAX_STACK_INSTS = 1000;

// state sizes in bytes for the three size tiers
constexpr int32_t RX_STACK_SMALL  = 112;
constexpr int32_t RX_STACK_MEDIUM = 1104;
constexpr int32_t RX_STACK_LARGE  = 10128;

// instruction counts derived from the sizes above using the 11-bytes-per-instruction bound
constexpr int32_t RX_SMALL_INSTS  = (RX_STACK_SMALL/11);
constexpr int32_t RX_MEDIUM_INSTS = (RX_STACK_MEDIUM/11);
constexpr int32_t RX_LARGE_INSTS  = (RX_STACK_LARGE/11);

} // namespace detail
} // namespace strings
} // namespace cudf

#include "./regex.inl"
b/cpp/src/strings/regex/regex.inl new file mode 100644 index 00000000000..8dce090f2a5 --- /dev/null +++ b/cpp/src/strings/regex/regex.inl @@ -0,0 +1,462 @@ +/* +* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cudf +{ +namespace strings +{ +namespace detail +{ + +/** + * @brief This holds the state information when evaluating a string + * against a regex pattern. + * + * There are 2 instances of this per string managed in the reljunk class. + * As each regex instruction is evaluated for a string, the result is + * reflected here. The regexec function updates and manages this state data. 
+ */ +struct alignas(8) relist +{ + int16_t size; + int16_t listsize; + int32_t reserved; + int2* ranges; // pair per instruction + int16_t* inst_ids; // one per instruction + u_char* mask; // bit per instruction + + __host__ __device__ inline static int32_t data_size_for(int32_t insts) + { + return ((sizeof(ranges[0])+sizeof(inst_ids[0]))*insts) + ((insts+7)/8); + } + + __host__ __device__ inline static int32_t alloc_size(int32_t insts) + { + int32_t size = sizeof(relist); + size += data_size_for(insts); + size = ((size+7)/8)*8; // align it too + return size; + } + + __host__ __device__ inline relist() {} + + __host__ __device__ inline void set_data(int16_t insts, u_char* data=nullptr) + { + listsize = insts; + u_char* ptr = (u_char*)data; + if( ptr==nullptr ) + ptr = (reinterpret_cast(this)) + sizeof(relist); + ranges = reinterpret_cast(ptr); + ptr += listsize * sizeof(ranges[0]); + inst_ids = reinterpret_cast(ptr); + ptr += listsize * sizeof(inst_ids[0]); + mask = ptr; + reset(); + } + + __host__ __device__ inline void reset() + { + memset(mask, 0, (listsize+7)/8); + size = 0; + } + + __device__ inline bool activate(int32_t i, int32_t begin, int32_t end) + { + if(readMask(i)) + return false; + writeMask(true, i); + inst_ids[size] = static_cast(i); + ranges[size] = int2{begin,end}; + ++size; + return true; + } + + __device__ inline void writeMask(bool v, int32_t pos) + { + u_char uc = 1 << (pos & 7); + if (v) + mask[pos >> 3] |= uc; + else + mask[pos >> 3] &= ~uc; + } + + __device__ inline bool readMask(int32_t pos) + { + u_char uc = mask[pos >> 3]; + return static_cast((uc >> (pos & 7)) & 1); + } +}; + +/** + * @brief This manages the two relist instances required by the regexec function. 
+ */ +struct reljunk +{ + relist* list1; + relist* list2; + int32_t starttype; + char32_t startchar; +}; + +__device__ inline void swaplist(relist*& l1, relist*& l2) +{ + relist* tmp = l1; + l1 = l2; + l2 = tmp; +} + +/** + * @brief Utility to check a specific character against this class instance. + * + * @param ch A 4-byte UTF-8 character. + * @param codepoint_flags Used for mapping a character to type for builtin classes. + * @return true if the character matches + */ +__device__ inline bool reclass_device::is_match(char32_t ch, const uint8_t* codepoint_flags) +{ + if( thrust::any_of( thrust::seq, thrust::make_counting_iterator(0), thrust::make_counting_iterator(count), + [ch, this] __device__ (int i) { return ((ch >= literals[i*2]) && (ch <= literals[(i*2)+1])); }) ) + return true; + if( !builtins ) + return false; + uint32_t codept = utf8_to_codepoint(ch); + if( codept > 0x00FFFF ) + return false; + int8_t fl = codepoint_flags[codept]; + if( (builtins & 1) && ((ch=='_') || IS_ALPHANUM(fl)) ) // \w + return true; + if( (builtins & 2) && IS_SPACE(fl) ) // \s + return true; + if( (builtins & 4) && IS_DIGIT(fl) ) // \d + return true; + if( (builtins & 8) && ((ch != '\n') && (ch != '_') && !IS_ALPHANUM(fl)) ) // \W + return true; + if( (builtins & 16) && !IS_SPACE(fl) ) // \S + return true; + if( (builtins & 32) && ((ch != '\n') && !IS_DIGIT(fl)) ) // \D + return true; + // + return false; +} + +/** + * @brief Set the device data to be used for holding the state data of a string. + * + * With one thread per string, the stack is used to maintain state when evaluating the string. + * With large regex patterns, the normal stack is not always practical. + * This mechanism allows an alternate buffer of device memory to be used in place of the stack + * for the state data. + * + * Two distinct buffers are required for the state data. 
// NOTE(review): several static_cast/reinterpret_cast expressions below carry no
// template argument in this copy of the source; the target types appear to have
// been lost in extraction -- confirm against the repository before building.

/**
 * @brief Stores the two caller-provided buffers that back the relist objects
 * created by call_regexec() when no global relist memory is set.
 */
__device__ inline void reprog_device::set_stack_mem(u_char* s1, u_char* s2)
{
    _stack_mem1 = s1;
    _stack_mem2 = s2;
}

/**
 * @brief Returns a pointer to instruction `idx`; the index is range-checked by assert only.
 */
__host__ __device__ inline reinst* reprog_device::get_inst(int32_t idx) const
{
    assert( (idx >= 0) && (idx < _insts_count) );
    return _insts + idx;
}

/**
 * @brief Returns a copy of class `idx`; the index is range-checked by assert only.
 */
__device__ inline reclass_device reprog_device::get_class(int32_t idx) const
{
    assert( (idx >= 0) && (idx < _classes_count) );
    return _classes[idx];
}

/**
 * @brief Returns the array of start instruction ids (read by regexec until a negative entry).
 */
__device__ inline int32_t* reprog_device::startinst_ids() const
{
    return _startinst_ids;
}

/**
 * @brief Evaluate a specific string against regex pattern compiled to this instance.
 *
 * This is the main function for executing the regex against an individual string.
 * Each iteration of the outer loop handles one character position: the start
 * instructions are (re)activated, zero-width instructions are expanded until
 * no further expansion happens, and then the character-consuming instructions
 * are executed against the current character.
 *
 * @param dstr String used for matching.
 * @param jnk State data object for this string.
 * @param[in,out] begin Character position to start evaluation. On return, it is the position of the match.
 * @param[in,out] end Character position to stop evaluation. On return, it is the end of the matched substring.
 * @param group_id Index of the group to match in a multi-group regex pattern.
 * @return >0 if match found
 */
__device__ inline int32_t reprog_device::regexec(string_view const& dstr, reljunk &jnk, int32_t& begin, int32_t& end, int32_t group_id)
{
    int32_t match = 0;
    auto checkstart = jnk.starttype; // non-zero: try the fast scan for the first character
    auto txtlen = dstr.length();
    auto pos = begin;
    auto eos = end;
    char32_t c = 0;
    string_view::const_iterator itr = string_view::const_iterator(dstr,pos);

    jnk.list1->reset();
    do
    {
        /* fast check for first char */
        if (checkstart)
        {
            switch (jnk.starttype)
            {
                case CHAR:
                {
                    // jump straight to the next occurrence of the start character
                    auto fidx = dstr.find(static_cast(jnk.startchar),pos);
                    if( fidx < 0 )
                        return match;
                    pos = fidx;
                    break;
                }
                case BOL:
                {
                    if( pos==0 )
                        break;
                    if( jnk.startchar != '^' )
                        return match;
                    // move to the character after the next newline
                    --pos;
                    int fidx = dstr.find(static_cast('\n'),pos);
                    if( fidx < 0 )
                        return match; // update begin/end values?
                    pos = fidx + 1;
                    break;
                }
            }
            itr = string_view::const_iterator(dstr,pos);
        }

        // seed the list with the start instructions while still searching for a match
        if ( ((eos < 0) || (pos < eos)) && match == 0)
        {
            //jnk.list1->activate(startinst_id, pos, 0);
            int32_t i = 0;
            auto ids = startinst_ids();
            while( ids[i] >=0 ) // a negative id terminates the array
                jnk.list1->activate(ids[i++], (group_id==0 ? pos:-1), -1);
        }

        // current character; 0 once the end of the string is reached
        c = static_cast(pos >= txtlen ? 0 : *itr);

        // expand LBRA, RBRA, BOL, EOL, BOW, NBOW, and OR
        bool expanded = false;
        do
        {
            jnk.list2->reset();
            expanded = false;

            for( int16_t i = 0; i < jnk.list1->size; i++)
            {
                int32_t inst_id = static_cast(jnk.list1->inst_ids[i]);
                int2& range = jnk.list1->ranges[i];
                const reinst* inst = get_inst(inst_id);
                int32_t id_activate = -1;

                switch(inst->type)
                {
                    // character-consuming instructions and END are carried over unchanged
                    case CHAR:
                    case ANY:
                    case ANYNL:
                    case CCLASS:
                    case NCCLASS:
                    case END:
                        id_activate = inst_id;
                        break;
                    case LBRA:
                        // record where the requested group starts
                        if(inst->u1.subid == group_id)
                            range.x = pos;
                        id_activate = inst->u2.next_id;
                        expanded = true;
                        break;
                    case RBRA:
                        // record where the requested group ends
                        if(inst->u1.subid == group_id)
                            range.y = pos;
                        id_activate = inst->u2.next_id;
                        expanded = true;
                        break;
                    case BOL:
                        if( (pos==0) || ((inst->u1.c=='^') && (dstr[pos-1]==static_cast('\n'))) )
                        {
                            id_activate = inst->u2.next_id;
                            expanded = true;
                        }
                        break;
                    case EOL:
                        if( (c==0) || (inst->u1.c == '$' && c == '\n'))
                        {
                            id_activate = inst->u2.next_id;
                            expanded = true;
                        }
                        break;
                    case BOW:
                    {
                        // boundary: exactly one of previous/current character is alphanumeric
                        auto codept = utf8_to_codepoint(c);
                        char32_t last_c = static_cast(pos ? dstr[pos-1] : 0);
                        auto last_codept = utf8_to_codepoint(last_c);
                        bool cur_alphaNumeric = (codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[codept]);
                        bool last_alphaNumeric = (last_codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[last_codept]);
                        if( cur_alphaNumeric != last_alphaNumeric )
                        {
                            id_activate = inst->u2.next_id;
                            expanded = true;
                        }
                        break;
                    }
                    case NBOW:
                    {
                        // not a boundary: previous and current character are of the same kind
                        auto codept = utf8_to_codepoint(c);
                        char32_t last_c = static_cast(pos ? dstr[pos-1] : 0);
                        auto last_codept = utf8_to_codepoint(last_c);
                        bool cur_alphaNumeric = (codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[codept]);
                        bool last_alphaNumeric = (last_codept < 0x010000) && IS_ALPHANUM(_codepoint_flags[last_codept]);
                        if( cur_alphaNumeric == last_alphaNumeric )
                        {
                            id_activate = inst->u2.next_id;
                            expanded = true;
                        }
                        break;
                    }
                    case OR:
                        // both alternatives become active: right here, left below
                        jnk.list2->activate(inst->u1.right_id, range.x, range.y);
                        id_activate = inst->u2.left_id;
                        expanded = true;
                        break;
                }
                if (id_activate >= 0)
                    jnk.list2->activate(id_activate, range.x, range.y);

            }
            swaplist(jnk.list1, jnk.list2);

        } while (expanded); // repeat until a pass expands nothing

        // execute
        jnk.list2->reset();
        for (int16_t i = 0; i < jnk.list1->size; i++)
        {
            int32_t inst_id = static_cast(jnk.list1->inst_ids[i]);
            int2& range = jnk.list1->ranges[i];
            const reinst* inst = get_inst(inst_id);
            int32_t id_activate = -1;

            switch(inst->type)
            {
                case CHAR:
                    if(inst->u1.c == c)
                        id_activate = inst->u2.next_id;
                    break;
                case ANY:
                    if(c != '\n')
                        id_activate = inst->u2.next_id;
                    break;
                case ANYNL:
                    id_activate = inst->u2.next_id;
                    break;
                case CCLASS:
                {
                    reclass_device cls = get_class(inst->u1.cls_id);
                    if( cls.is_match(c,_codepoint_flags) )
                        id_activate = inst->u2.next_id;
                    break;
                }
                case NCCLASS:
                {
                    reclass_device cls = get_class(inst->u1.cls_id);
                    if( !cls.is_match(c,_codepoint_flags) )
                        id_activate = inst->u2.next_id;
                    break;
                }
                case END:
                    // match found: report its range and stop processing the remaining entries
                    match = 1;
                    begin = range.x;
                    end = group_id==0? pos : range.y;
                    goto BreakFor;
            }
            if (id_activate >= 0)
                jnk.list2->activate(id_activate, range.x, range.y);
        }

    BreakFor:
        ++pos;
        ++itr;
        swaplist(jnk.list1, jnk.list2);
        // the fast scan is only usable again when no instruction is active
        checkstart = jnk.list1->size > 0 ? 0 : 1;
    }
    while (c && (jnk.list1->size>0 || match == 0));
    return match;
}

/**
 * @brief Searches the string for the pattern; on failure begin and end are both set to -1.
 *
 * @return The result of call_regexec (values <= 0 are treated as "no match").
 */
__device__ inline int32_t reprog_device::find( int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end )
{
    int32_t rtn = call_regexec(idx,dstr,begin,end);
    if( rtn <=0 )
        begin = end = -1;
    return rtn;
}

/**
 * @brief Evaluates the pattern for one capture group.
 *
 * `end` is overwritten with begin+1 before the search and the group is
 * passed on as group_id+1 (group id 0 is used by find() via the default argument).
 */
__device__ inline int32_t reprog_device::extract( int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id )
{
    end = begin + 1;
    return call_regexec(idx,dstr,begin,end,group_id+1);
}

/**
 * @brief Sets up the state structures for one string and runs regexec.
 *
 * When no relist memory was allocated for the program (_relists_mem is null),
 * two local relist objects are used whose arrays live in the buffers given to
 * set_stack_mem(). Otherwise the two relist objects for string `idx` are
 * carved out of the _relists_mem buffer.
 */
__device__ inline int32_t reprog_device::call_regexec( int32_t idx, string_view const& dstr, int32_t& begin, int32_t& end, int32_t group_id )
{
    reljunk jnk;
    jnk.starttype = 0;
    jnk.startchar = 0;
    // enable the first-character fast path only for CHAR and BOL start instructions
    int type = get_inst(_startinst_id)->type;
    if( type == CHAR || type == BOL )
    {
        jnk.starttype = type;
        jnk.startchar = get_inst(_startinst_id)->u1.c;
    }

    if( _relists_mem==0 )
    {
        relist relist1;
        relist relist2;
        jnk.list1 = &relist1;
        jnk.list2 = &relist2;
        jnk.list1->set_data(static_cast(_insts_count),_stack_mem1);
        jnk.list2->set_data(static_cast(_insts_count),_stack_mem2);
        return regexec(dstr,jnk,begin,end,group_id);
    }

    auto relists_size = relist::alloc_size(_insts_count);
    u_char* drel = reinterpret_cast(_relists_mem); // beginning of relist buffer;
    drel += (idx * relists_size * 2);              // two relist ptrs in reljunk:
    jnk.list1 = reinterpret_cast(drel);                // - first one
    jnk.list2 = reinterpret_cast(drel + relists_size); // - second one
    jnk.list1->set_data(static_cast(_insts_count)); // essentially this is
    jnk.list2->set_data(static_cast(_insts_count)); // substitute ctor call
    return regexec(dstr,jnk,begin,end,group_id);
}
2019, NVIDIA CORPORATION. All rights reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include +#include "./regex.cuh" +#include "./regcomp.h" + +#include +#include +#include + +namespace cudf +{ +namespace strings +{ +namespace detail +{ +namespace +{ + +/** + * @brief Converts UTF-8 string into fixed-width 32-bit character vector. + * + * No character conversion occurs. + * Each UTF-8 character is promoted into a 32-bit value. + * The last entry in the returned vector will be a 0 value. + * The fixed-width vector makes it easier to compile and faster to execute. + * + * @param pattern Regular expression encoded with UTF-8. + * @return Fixed-width 32-bit character vector. 
+ */ +std::vector string_to_char32_vector( std::string const& pattern ) +{ + size_type size = static_cast(pattern.size()); + size_type count = characters_in_string(pattern.c_str(),size); + std::vector result(count+1); + char32_t* output_ptr = result.data(); + const char* input_ptr = pattern.data(); + for( size_type idx=0; idx < size; ++idx ) + { + char_utf8 output_character = 0; + size_type ch_width = to_char_utf8(input_ptr,output_character); + input_ptr += ch_width; + idx += ch_width - 1; + *output_ptr++ = output_character; + } + result[count] = 0; // last entry set to 0 + return result; +} + +} + +// Copy reprog primitive values +reprog_device::reprog_device(reprog& prog) + : _startinst_id{prog.get_start_inst()}, + _num_capturing_groups{prog.groups_count()}, + _insts_count{prog.insts_count()}, + _starts_count{prog.starts_count()}, + _classes_count{prog.classes_count()}, + _relists_mem{nullptr}, + _stack_mem1{nullptr}, + _stack_mem2{nullptr} +{} + +// Create instance of the reprog that can be passed into a device kernel +std::unique_ptr> + reprog_device::create(std::string const& pattern, const uint8_t* codepoint_flags, size_type strings_count, cudaStream_t stream ) +{ + std::vector pattern32 = string_to_char32_vector(pattern); + // compile pattern into host object + reprog h_prog = reprog::create_from(pattern32.data()); + // compute size to hold all the member data + auto insts_count = h_prog.insts_count(); + auto classes_count = h_prog.classes_count(); + auto starts_count = h_prog.starts_count(); + auto insts_size = insts_count * sizeof(_insts[0]); + auto startids_size = starts_count * sizeof(_startinst_ids[0]); + auto classes_size = classes_count * sizeof(_classes[0]); + for( int32_t idx=0; idx < classes_count; ++idx ) + classes_size += static_cast((h_prog.class_at(idx).literals.size())*sizeof(char32_t)); + size_t memsize = insts_size + startids_size + classes_size; + size_t rlm_size = 0; + // check memory size needed for executing regex + if( insts_count > 
MAX_STACK_INSTS ) + { + auto relist_alloc_size = relist::alloc_size(insts_count); + size_t rlm_size = relist_alloc_size*2L*strings_count; // reljunk has 2 relist ptrs + size_t freeSize = 0; + size_t totalSize = 0; + rmmGetInfo(&freeSize,&totalSize,stream); + if( rlm_size + memsize > freeSize ) // do not allocate more than we have + { // otherwise, this is unrecoverable + std::ostringstream message; + message << "cuDF failure at: " __FILE__ ":" << __LINE__ << ": "; + message << "number of instructions (" << insts_count << ") "; + message << "and number of strings (" << strings_count << ") "; + message << "exceeds available memory"; + throw cudf::logic_error(message.str()); + } + } + + // allocate memory to store prog data + std::vector h_buffer(memsize); + u_char* h_ptr = h_buffer.data(); // running pointer + u_char* d_buffer = 0; + RMM_TRY(RMM_ALLOC(&d_buffer,memsize,stream)); + u_char* d_ptr = d_buffer; // running device pointer + // put everything into a flat host buffer first + reprog_device* d_prog = new reprog_device(h_prog); + // copy the instructions array first (fixed-size structs) + reinst* insts = reinterpret_cast(h_ptr); + memcpy( insts, h_prog.insts_data(), insts_size); + h_ptr += insts_size; // next section + d_prog->_insts = reinterpret_cast(d_ptr); + d_ptr += insts_size; + // copy the startinst_ids next (ints) + int32_t* startinst_ids = reinterpret_cast(h_ptr); + memcpy( startinst_ids, h_prog.starts_data(), startids_size ); + h_ptr += startids_size; // next section + d_prog->_startinst_ids = reinterpret_cast(d_ptr); + d_ptr += startids_size; + // copy classes into flat memory: [class1,class2,...][char32 arrays] + reclass_device* classes = reinterpret_cast(h_ptr); + d_prog->_classes = reinterpret_cast(d_ptr); + // get pointer to the end to handle variable length data + u_char* h_end = h_ptr + (classes_count * sizeof(reclass_device)); + u_char* d_end = d_ptr + (classes_count * sizeof(reclass_device)); + // place each class and append the variable 
length data + for( int32_t idx=0; idx < classes_count; ++idx ) + { + reclass& h_class = h_prog.class_at(idx); + reclass_device d_class; + d_class.builtins = h_class.builtins; + d_class.count = h_class.literals.size(); + d_class.literals = reinterpret_cast(d_end); + memcpy( classes++, &d_class, sizeof(d_class) ); + memcpy( h_end, h_class.literals.c_str(), h_class.literals.size()*sizeof(char32_t) ); + h_end += h_class.literals.size()*sizeof(char32_t); + d_end += h_class.literals.size()*sizeof(char32_t); + } + // initialize the rest of the elements + d_prog->_insts_count = insts_count; + d_prog->_starts_count = starts_count; + d_prog->_classes_count = classes_count; + d_prog->_codepoint_flags = codepoint_flags; + // allocate execute memory if needed + if( rlm_size > 0 ) + { + RMM_TRY(RMM_ALLOC(&(d_prog->_relists_mem),rlm_size,stream)); + } + + // copy flat prog to device memory + CUDA_TRY(cudaMemcpy(d_buffer,h_buffer.data(),memsize,cudaMemcpyHostToDevice)); + // + auto deleter = [](reprog_device*t) {t->destroy();}; + return std::unique_ptr>(d_prog,deleter); +} + +void reprog_device::destroy() +{ + if( _relists_mem ) + RMM_FREE(_relists_mem,0); + RMM_FREE(_insts,0); + delete this; +} + +} // namespace detail +} // namespace strings +} // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 67f8f3f426b..0c458c1fb83 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -751,6 +751,7 @@ set(STRINGS_TEST_SRC "${CMAKE_CURRENT_SOURCE_DIR}/strings/chars_types_tests.cu" "${CMAKE_CURRENT_SOURCE_DIR}/strings/combine_tests.cu" "${CMAKE_CURRENT_SOURCE_DIR}/strings/concatenate_tests.cu" + "${CMAKE_CURRENT_SOURCE_DIR}/strings/contains_tests.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/strings/datetime_tests.cu" "${CMAKE_CURRENT_SOURCE_DIR}/strings/fill_tests.cu" "${CMAKE_CURRENT_SOURCE_DIR}/strings/find_tests.cu" diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp new file mode 100644 index 
00000000000..45e97574d1d --- /dev/null +++ b/cpp/tests/strings/contains_tests.cpp @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + + +struct StringsContainsTests : public cudf::test::BaseFixture {}; + + +TEST_F(StringsContainsTests, ContainsTest) +{ + std::vector h_strings{ + "5", + "hej", + "\t \n", + "12345", + "\\", + "d", + "c:\\Tools", + "+27", + "1c2", + "1C2", + "0:00:0", + "0:0:00", + "00:0:0", + "00:00:0", + "00:0:00", + "0:00:00", + "00:00:00", + "Hello world !", + "Hello world! 
", + "Hello worldcup !", + "0123456789", + "1C2", + "Xaa", + "abcdefghxxx", + "ABCDEFGH", + "abcdefgh", + "abc def", + "abc\ndef", + "aa\r\nbb\r\ncc\r\n\r\n", + "abcabc", + nullptr, "" }; + + cudf::test::strings_column_wrapper strings( h_strings.begin(), h_strings.end(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + auto strings_view = cudf::strings_column_view(strings); + + std::vector patterns{ + "\\d", + "\\w+", + "\\s", + "\\S", + "^.*\\\\.*$", + "[1-5]+", + "[a-h]+", + "[A-H]+", + "\n", + "b.\\s*\n", + ".*c", + "\\d\\d:\\d\\d:\\d\\d", + "\\d\\d?:\\d\\d?:\\d\\d?", + "[Hh]ello [Ww]orld", + "\\bworld\\b" }; + + std::vector h_expecteds{ // strings.size x patterns.size + true, false, false, true, false, false, false, true, true, true, true, true, true, true, true, true, true, false, false, false, true, true, false, false, false, false, false, false, false, false, false, false, + true, true, false, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, + false, false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, true, true, false, false, false, false, false, false, true, true, true, false, false, false, + true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, + false, false, false, false, true, false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, + true, false, false, true, false, false, false, true, true, true, false, false, false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false, false, false, false, false, 
+ false, true, false, false, false, true, true, false, true, false, false, false, false, false, false, false, false, true, true, true, false, false, true, true, false, true, true, true, true, true, false, false, + false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false, false, true, true, true, false, true, false, false, true, false, false, false, false, false, false, false, + false, false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, true, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, true, false, false, false, + false, false, false, false, false, false, true, false, true, false, false, false, false, false, false, false, false, false, false, true, false, false, false, true, false, true, true, true, true, true, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false + }; + + for( 
int idx=0; idx < static_cast(patterns.size()); ++idx ) + { /* row idx of the flat h_expecteds table holds the h_strings.size() expected flags for patterns[idx] */ + std::string ptn = patterns[idx]; + auto results = cudf::strings::contains_re(strings_view,ptn); + cudf::experimental::bool8* h_expected = h_expecteds.data() + (idx * h_strings.size()); + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } +} + +/** + * Exercises matches_re on seven rows (the sixth is a nullptr entry, the seventh an empty string) with the patterns "lazy", "\\d+" and "@\\w+". + * The expected flags are true only for rows that begin with a match: "lazy @dog" for the first pattern, "1234" and "00:0:00" for the second, and no row for the third even though "@fox" and "@dog" occur later inside two rows. + * The validity iterator built from h_strings marks the nullptr row as null in both the input column and each expected column. + */ +TEST_F(StringsContainsTests, MatchesTest) +{ + std::vector h_strings{ + "The quick brown @fox jumps", "ovér the", "lazy @dog", "1234", "00:0:00", nullptr, "" }; + cudf::test::strings_column_wrapper strings( h_strings.begin(), h_strings.end(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + + auto strings_view = cudf::strings_column_view(strings); + { + auto results = cudf::strings::matches_re(strings_view,"lazy"); + cudf::experimental::bool8 h_expected[] = {false,false,true,false,false,false,false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::matches_re(strings_view,"\\d+"); + cudf::experimental::bool8 h_expected[] = {false,false,false,true,true,false,false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::matches_re(strings_view,"@\\w+"); + cudf::experimental::bool8 h_expected[] = {false,false,false,false,false,false,false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( 
h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } +} +/** + * Exercises count_re on six rows (the fifth is a nullptr entry, the sixth an empty string) with the patterns "[tT]he", "@\\w+" and "\\d+:\\d+". + * Expected per-row counts are {2,0,0,0,0,0}, {1,1,0,0,0,0} and {0,0,2,1,0,0} respectively. + * The last case expects 2 for "1:2:3:4" and 1 for "00:0:00", which is consistent with counting non-overlapping matches. + * The validity iterator built from h_strings marks the nullptr row as null in both the input column and each expected column. + */ +TEST_F(StringsContainsTests, CountTest) +{ + std::vector h_strings{ + "The quick brown @fox jumps ovér the", "lazy @dog", "1:2:3:4", "00:0:00", nullptr, "" }; + cudf::test::strings_column_wrapper strings( h_strings.begin(), h_strings.end(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + + auto strings_view = cudf::strings_column_view(strings); + { + auto results = cudf::strings::count_re(strings_view,"[tT]he"); + int32_t h_expected[] = {2,0,0,0,0,0}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::count_re(strings_view,"@\\w+"); + int32_t h_expected[] = {1,1,0,0,0,0}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::count_re(strings_view,"\\d+:\\d+"); + int32_t h_expected[] = {0,0,2,1,0,0}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } +} +/** + * Runs contains_re, matches_re and count_re with one long literal pattern over three rows that contain no null entries. + * Only the first row, which begins with the pattern text, is expected to match: flags {true, false, false} and counts {1,0,0}. + */ +TEST_F(StringsContainsTests, MediumRegex) +{ + // This results in 95 regex instructions and falls in the 'medium' range. 
+ std::string medium_regex = "hello @abc @def world The quick brown @fox jumps over the lazy @dog hello http://www.world.com"; + + std::vector h_strings{ + "hello @abc @def world The quick brown @fox jumps over the lazy @dog hello http://www.world.com thats all", + "12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + }; + cudf::test::strings_column_wrapper strings( h_strings.begin(), h_strings.end(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + + auto strings_view = cudf::strings_column_view(strings); + { + auto results = cudf::strings::contains_re(strings_view, medium_regex); + cudf::experimental::bool8 h_expected[] = {true, false, false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::matches_re(strings_view, medium_regex); + cudf::experimental::bool8 h_expected[] = {true, false, false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::count_re(strings_view, medium_regex); + int32_t h_expected[] = {1,0,0}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } +} +/** + * Runs contains_re, matches_re and count_re with a longer literal pattern over three rows that contain no null entries. + * The first row is identical to the pattern text and is the only row expected to match: flags {true, false, false} and counts {1,0,0}. + */ +TEST_F(StringsContainsTests, LargeRegex) +{ + // This results in 115 regex 
instructions and falls in the 'large' range. + std::string large_regex = "hello @abc @def world The quick brown @fox jumps over the lazy @dog hello http://www.world.com I'm here @home zzzz"; + + std::vector h_strings{ + "hello @abc @def world The quick brown @fox jumps over the lazy @dog hello http://www.world.com I'm here @home zzzz", + "12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + }; + cudf::test::strings_column_wrapper strings( h_strings.begin(), h_strings.end(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + + auto strings_view = cudf::strings_column_view(strings); + { + auto results = cudf::strings::contains_re(strings_view, large_regex); + cudf::experimental::bool8 h_expected[] = {true, false, false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::matches_re(strings_view, large_regex); + cudf::experimental::bool8 h_expected[] = {true, false, false}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } + { + auto results = cudf::strings::count_re(strings_view, large_regex); + int32_t h_expected[] = {1,0,0}; + cudf::test::fixed_width_column_wrapper expected( h_expected, h_expected+h_strings.size(), + thrust::make_transform_iterator( h_strings.begin(), [] (auto str) { return str!=nullptr; })); + cudf::test::expect_columns_equal(*results,expected); + } +}