diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 7be88d01387..244cec1d780 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -830,6 +830,7 @@ class regex_compiler { m_prog.set_start_inst(andstack[andstack.size() - 1].id_first); m_prog.optimize1(); m_prog.optimize2(); + m_prog.check_for_errors(); m_prog.set_groups_count(cursubid); } }; @@ -926,6 +927,68 @@ void reprog::optimize2() _startinst_ids.push_back(-1); // terminator mark } +/** + * @brief Check a specific instruction for errors. + * + * Currently this is checking for an infinite-loop condition as documented in this issue: + * https://github.com/rapidsai/cudf/issues/10006 + * + * Example instructions list created from pattern `(A?)+` + * ``` + * 0: CHAR c='A', next=2 + * 1: OR right=0, left=2, next=2 + * 2: RBRA id=1, next=4 + * 3: LBRA id=1, next=1 + * 4: OR right=3, left=5, next=5 + * 5: END + * ``` + * + * Following the example above, the instruction at `id==1` (OR) + * is being checked. If the instruction path returns to `id==1` + * without including the `0==CHAR` or `5==END` as in this example, + * then this would cause the runtime to go into an infinite-loop. + * + * It appears this example pattern is not valid. But Python interprets + * its behavior similarly to pattern `(A*)`. Handling this in the same + * way does not look feasible with the current implementation. + * + * @throw cudf::logic_error if instruction logic error is found + * + * @param id Instruction to check if repeated. + * @param next_id Next instruction to process. + */ +void reprog::check_for_errors(int32_t id, int32_t next_id) +{ + auto inst = inst_at(next_id); + while (inst.type == LBRA || inst.type == RBRA) { + next_id = inst.u2.next_id; + inst = inst_at(next_id); + } + if (inst.type == OR) { + CUDF_EXPECTS(next_id != id, "Unsupported regex pattern"); + check_for_errors(id, inst.u2.left_id); + check_for_errors(id, inst.u1.right_id); + } +} + +/** + * @brief Check regex instruction set for any errors. + * + * Currently, this checks for OR instructions that eventually point back to themselves with only + * intervening capture group instructions between causing an infinite-loop during runtime + * evaluation. + */ +void reprog::check_for_errors() +{ + for (auto id = 0; id < insts_count(); ++id) { + auto const inst = inst_at(id); + if (inst.type == OR) { + check_for_errors(id, inst.u2.left_id); + check_for_errors(id, inst.u1.right_id); + } + } +} + #ifndef NDEBUG void reprog::print(regex_flags const flags) { @@ -933,83 +996,81 @@ void reprog::print(regex_flags const flags) printf("Instructions:\n"); for (std::size_t i = 0; i < _insts.size(); i++) { const reinst& inst = _insts[i]; - printf("%zu :", i); + printf("%3zu: ", i); switch (inst.type) { - default: printf("Unknown instruction: %d, nextid= %d", inst.type, inst.u2.next_id); break; + default: printf("Unknown instruction: %d, next=%d", inst.type, inst.u2.next_id); break; case CHAR: - if (inst.u1.c <= 32 || inst.u1.c >= 127) - printf( - "CHAR, c = '0x%02x', nextid= %d", static_cast(inst.u1.c), inst.u2.next_id); - else - printf("CHAR, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); + if (inst.u1.c <= 32 || inst.u1.c >= 127) { + printf(" CHAR c='0x%02x', next=%d", static_cast(inst.u1.c), inst.u2.next_id); + } else { + printf(" CHAR c='%c', next=%d", inst.u1.c, inst.u2.next_id); + } break; - case RBRA: printf("RBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); break; - case LBRA: printf("LBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); break; + case RBRA: printf(" RBRA id=%d, next=%d", inst.u1.subid, inst.u2.next_id); break; + case LBRA: printf(" LBRA id=%d, next=%d", inst.u1.subid, inst.u2.next_id); break; case OR: - printf("OR, rightid=%d, leftid=%d, nextid=%d", - inst.u1.right_id, - inst.u2.left_id, - inst.u2.next_id); + printf( + " OR right=%d, left=%d, next=%d", inst.u1.right_id, inst.u2.left_id, inst.u2.next_id); break; - case STAR: printf("STAR, nextid= %d", inst.u2.next_id); break; - case PLUS: printf("PLUS, nextid= %d", inst.u2.next_id); break; - case QUEST: printf("QUEST, nextid= %d", inst.u2.next_id); break; - case ANY: printf("ANY, nextid= %d", inst.u2.next_id); break; - case ANYNL: printf("ANYNL, nextid= %d", inst.u2.next_id); break; - case NOP: printf("NOP, nextid= %d", inst.u2.next_id); break; + case STAR: printf(" STAR next=%d", inst.u2.next_id); break; + case PLUS: printf(" PLUS next=%d", inst.u2.next_id); break; + case QUEST: printf(" QUEST next=%d", inst.u2.next_id); break; + case ANY: printf(" ANY next=%d", inst.u2.next_id); break; + case ANYNL: printf(" ANYNL next=%d", inst.u2.next_id); break; + case NOP: printf(" NOP next=%d", inst.u2.next_id); break; case BOL: { - printf("BOL, c = "); + printf(" BOL c="); if (inst.u1.c == '\n') { printf("'\\n'"); } else { printf("'%c'", inst.u1.c); } - printf(", nextid= %d", inst.u2.next_id); + printf(", next=%d", inst.u2.next_id); break; } case EOL: { - printf("EOL, c = "); + printf(" EOL c="); if (inst.u1.c == '\n') { printf("'\\n'"); } else { printf("'%c'", inst.u1.c); } - printf(", nextid= %d", inst.u2.next_id); + printf(", next=%d", inst.u2.next_id); break; } - case CCLASS: printf("CCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); break; - case NCCLASS: - printf("NCCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); - break; - case BOW: printf("BOW, nextid= %d", inst.u2.next_id); break; - case NBOW: printf("NBOW, nextid= %d", inst.u2.next_id); break; - case END: printf("END"); break; + case CCLASS: printf(" CCLASS cls=%d , next=%d", inst.u1.cls_id, inst.u2.next_id); break; + case NCCLASS: printf("NCCLASS cls=%d, next=%d", inst.u1.cls_id, inst.u2.next_id); break; + case BOW: printf(" BOW next=%d", inst.u2.next_id); break; + case NBOW: printf(" NBOW next=%d", inst.u2.next_id); break; + case END: printf(" END"); break; } printf("\n"); } printf("startinst_id=%d\n", _startinst_id); if (_startinst_ids.size() > 0) { - printf("startinst_ids:"); - for (size_t i = 0; i < _startinst_ids.size(); i++) + printf("startinst_ids: ["); + for (size_t i = 0; i < _startinst_ids.size(); i++) { printf(" %d", _startinst_ids[i]); - printf("\n"); + } + printf("]\n"); } int count = static_cast(_classes.size()); printf("\nClasses %d\n", count); for (int i = 0; i < count; i++) { const reclass& cls = _classes[i]; - int len = static_cast(cls.literals.size()); + auto const size = static_cast(cls.literals.size()); printf("%2d: ", i); - for (int j = 0; j < len; j += 2) { + for (int j = 0; j < size; j += 2) { char32_t c1 = cls.literals[j]; char32_t c2 = cls.literals[j + 1]; - if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) + if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) { printf("0x%02x-0x%02x", static_cast(c1), static_cast(c2)); - else + } else { printf("%c-%c", static_cast(c1), static_cast(c2)); - if ((j + 2) < len) printf(", "); + } + if ((j + 2) < size) { printf(", "); } } printf("\n"); if (cls.builtins) { @@ -1024,7 +1085,7 @@ void reprog::print(regex_flags const flags) } printf("\n"); } - if (_num_capturing_groups) printf("Number of capturing groups: %d\n", _num_capturing_groups); + if (_num_capturing_groups) { printf("Number of capturing groups: %d\n", _num_capturing_groups); } } #endif diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h index 3131767de59..18735d0f980 100644 --- a/cpp/src/strings/regex/regcomp.h +++ b/cpp/src/strings/regex/regcomp.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -116,14 +116,19 @@ class reprog { void optimize1(); void optimize2(); + void check_for_errors(); +#ifndef NDEBUG void print(regex_flags const flags); +#endif private: std::vector _insts; std::vector _classes; int32_t _startinst_id; std::vector _startinst_ids; // short-cut to speed-up ORs - int32_t _num_capturing_groups; + int32_t _num_capturing_groups{}; + + void check_for_errors(int32_t id, int32_t next_id); }; } // namespace detail diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp index 48c4aac9e8a..12a00aa35ab 100644 --- a/cpp/tests/strings/contains_tests.cpp +++ b/cpp/tests/strings/contains_tests.cpp @@ -274,6 +274,15 @@ TEST_F(StringsContainsTests, EmbeddedNullCharacter) CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); } +TEST_F(StringsContainsTests, Errors) +{ + cudf::test::strings_column_wrapper input({"3", "33"}); + auto strings_view = cudf::strings_column_view(input); + + EXPECT_THROW(cudf::strings::contains_re(strings_view, "(3?)+"), cudf::logic_error); + EXPECT_THROW(cudf::strings::contains_re(strings_view, "3?+"), cudf::logic_error); +} + TEST_F(StringsContainsTests, CountTest) { std::vector h_strings{