Skip to content

Commit

Permalink
Add check for regex instructions causing an infinite-loop (#10095)
Browse files Browse the repository at this point in the history
Closes #10006 

Fixes a use case where the regex pattern creates a set of instructions that can cause the regex evaluation process to go into an infinite loop. For example, the pattern `(x?)+` creates the following instructions:

```
Instructions:
  0:    CHAR c='x', next=2
  1:      OR right=0, left=2, next=2
  2:    RBRA id=1, next=4
  3:    LBRA id=1, next=1
  4:      OR right=3, left=5, next=5
  5:     END
startinst_id=3
startinst_ids: [ 3 -1]
```

This causes in an infinite loop at instruction 4 where the path may go like: 4->3->1->2->4 ... forever.
Supporting this pattern does not look possible. The `+` quantifier is applied to capture group symbol `)` inside of which `x?` means 0 or more repeating the character `x`. This means it could match `x` or nothing and so applying the `+` to nothing would be invalid. That said, the pattern `x?+` currently already throws an error because of the invalid usage of `+` quantifier.

Therefore, the fix here adds a checking step after the instruction set is created to check for a possible infinite-loop case. If one is detected, an exception is thrown indicating the pattern is not supported.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Devavret Makkar (https://github.com/devavret)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #10095
  • Loading branch information
davidwendt authored Jan 27, 2022
1 parent 7c69dae commit b2d5874
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 41 deletions.
139 changes: 100 additions & 39 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ class regex_compiler {
m_prog.set_start_inst(andstack[andstack.size() - 1].id_first);
m_prog.optimize1();
m_prog.optimize2();
m_prog.check_for_errors();
m_prog.set_groups_count(cursubid);
}
};
Expand Down Expand Up @@ -926,90 +927,150 @@ void reprog::optimize2()
_startinst_ids.push_back(-1); // terminator mark
}

/**
* @brief Check a specific instruction for errors.
*
* Currently this is checking for an infinite-loop condition as documented in this issue:
* https://github.com/rapidsai/cudf/issues/10006
*
* Example instructions list created from pattern `(A?)+`
* ```
* 0: CHAR c='A', next=2
* 1: OR right=0, left=2, next=2
* 2: RBRA id=1, next=4
* 3: LBRA id=1, next=1
* 4: OR right=3, left=5, next=5
* 5: END
* ```
*
* Following the example above, the instruction at `id==1` (OR)
* is being checked. If the instruction path returns to `id==1`
* without including the `0==CHAR` or `5==END` as in this example,
* then this would cause the runtime to go into an infinite-loop.
*
* It appears this example pattern is not valid. But Python interprets
* its behavior similarly to pattern `(A*)`. Handling this in the same
* way does not look feasible with the current implementation.
*
* @throw cudf::logic_error if instruction logic error is found
*
* @param id Instruction to check if repeated.
* @param next_id Next instruction to process.
*/
void reprog::check_for_errors(int32_t id, int32_t next_id)
{
auto inst = inst_at(next_id);
while (inst.type == LBRA || inst.type == RBRA) {
next_id = inst.u2.next_id;
inst = inst_at(next_id);
}
if (inst.type == OR) {
CUDF_EXPECTS(next_id != id, "Unsupported regex pattern");
check_for_errors(id, inst.u2.left_id);
check_for_errors(id, inst.u1.right_id);
}
}

/**
* @brief Check regex instruction set for any errors.
*
* Currently, this checks for OR instructions that eventually point back to themselves with only
* intervening capture group instructions between causing an infinite-loop during runtime
* evaluation.
*/
void reprog::check_for_errors()
{
for (auto id = 0; id < insts_count(); ++id) {
auto const inst = inst_at(id);
if (inst.type == OR) {
check_for_errors(id, inst.u2.left_id);
check_for_errors(id, inst.u1.right_id);
}
}
}

#ifndef NDEBUG
void reprog::print(regex_flags const flags)
{
printf("Flags = 0x%08x\n", static_cast<uint32_t>(flags));
printf("Instructions:\n");
for (std::size_t i = 0; i < _insts.size(); i++) {
const reinst& inst = _insts[i];
printf("%zu :", i);
printf("%3zu: ", i);
switch (inst.type) {
default: printf("Unknown instruction: %d, nextid= %d", inst.type, inst.u2.next_id); break;
default: printf("Unknown instruction: %d, next=%d", inst.type, inst.u2.next_id); break;
case CHAR:
if (inst.u1.c <= 32 || inst.u1.c >= 127)
printf(
"CHAR, c = '0x%02x', nextid= %d", static_cast<unsigned>(inst.u1.c), inst.u2.next_id);
else
printf("CHAR, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id);
if (inst.u1.c <= 32 || inst.u1.c >= 127) {
printf(" CHAR c='0x%02x', next=%d", static_cast<unsigned>(inst.u1.c), inst.u2.next_id);
} else {
printf(" CHAR c='%c', next=%d", inst.u1.c, inst.u2.next_id);
}
break;
case RBRA: printf("RBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); break;
case LBRA: printf("LBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); break;
case RBRA: printf(" RBRA id=%d, next=%d", inst.u1.subid, inst.u2.next_id); break;
case LBRA: printf(" LBRA id=%d, next=%d", inst.u1.subid, inst.u2.next_id); break;
case OR:
printf("OR, rightid=%d, leftid=%d, nextid=%d",
inst.u1.right_id,
inst.u2.left_id,
inst.u2.next_id);
printf(
" OR right=%d, left=%d, next=%d", inst.u1.right_id, inst.u2.left_id, inst.u2.next_id);
break;
case STAR: printf("STAR, nextid= %d", inst.u2.next_id); break;
case PLUS: printf("PLUS, nextid= %d", inst.u2.next_id); break;
case QUEST: printf("QUEST, nextid= %d", inst.u2.next_id); break;
case ANY: printf("ANY, nextid= %d", inst.u2.next_id); break;
case ANYNL: printf("ANYNL, nextid= %d", inst.u2.next_id); break;
case NOP: printf("NOP, nextid= %d", inst.u2.next_id); break;
case STAR: printf(" STAR next=%d", inst.u2.next_id); break;
case PLUS: printf(" PLUS next=%d", inst.u2.next_id); break;
case QUEST: printf(" QUEST next=%d", inst.u2.next_id); break;
case ANY: printf(" ANY next=%d", inst.u2.next_id); break;
case ANYNL: printf(" ANYNL next=%d", inst.u2.next_id); break;
case NOP: printf(" NOP next=%d", inst.u2.next_id); break;
case BOL: {
printf("BOL, c = ");
printf(" BOL c=");
if (inst.u1.c == '\n') {
printf("'\\n'");
} else {
printf("'%c'", inst.u1.c);
}
printf(", nextid= %d", inst.u2.next_id);
printf(", next=%d", inst.u2.next_id);
break;
}
case EOL: {
printf("EOL, c = ");
printf(" EOL c=");
if (inst.u1.c == '\n') {
printf("'\\n'");
} else {
printf("'%c'", inst.u1.c);
}
printf(", nextid= %d", inst.u2.next_id);
printf(", next=%d", inst.u2.next_id);
break;
}
case CCLASS: printf("CCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); break;
case NCCLASS:
printf("NCCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id);
break;
case BOW: printf("BOW, nextid= %d", inst.u2.next_id); break;
case NBOW: printf("NBOW, nextid= %d", inst.u2.next_id); break;
case END: printf("END"); break;
case CCLASS: printf(" CCLASS cls=%d , next=%d", inst.u1.cls_id, inst.u2.next_id); break;
case NCCLASS: printf("NCCLASS cls=%d, next=%d", inst.u1.cls_id, inst.u2.next_id); break;
case BOW: printf(" BOW next=%d", inst.u2.next_id); break;
case NBOW: printf(" NBOW next=%d", inst.u2.next_id); break;
case END: printf(" END"); break;
}
printf("\n");
}

printf("startinst_id=%d\n", _startinst_id);
if (_startinst_ids.size() > 0) {
printf("startinst_ids:");
for (size_t i = 0; i < _startinst_ids.size(); i++)
printf("startinst_ids: [");
for (size_t i = 0; i < _startinst_ids.size(); i++) {
printf(" %d", _startinst_ids[i]);
printf("\n");
}
printf("]\n");
}

int count = static_cast<int>(_classes.size());
printf("\nClasses %d\n", count);
for (int i = 0; i < count; i++) {
const reclass& cls = _classes[i];
int len = static_cast<int>(cls.literals.size());
auto const size = static_cast<int>(cls.literals.size());
printf("%2d: ", i);
for (int j = 0; j < len; j += 2) {
for (int j = 0; j < size; j += 2) {
char32_t c1 = cls.literals[j];
char32_t c2 = cls.literals[j + 1];
if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127)
if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) {
printf("0x%02x-0x%02x", static_cast<unsigned>(c1), static_cast<unsigned>(c2));
else
} else {
printf("%c-%c", static_cast<char>(c1), static_cast<char>(c2));
if ((j + 2) < len) printf(", ");
}
if ((j + 2) < size) { printf(", "); }
}
printf("\n");
if (cls.builtins) {
Expand All @@ -1024,7 +1085,7 @@ void reprog::print(regex_flags const flags)
}
printf("\n");
}
if (_num_capturing_groups) printf("Number of capturing groups: %d\n", _num_capturing_groups);
if (_num_capturing_groups) { printf("Number of capturing groups: %d\n", _num_capturing_groups); }
}
#endif

Expand Down
9 changes: 7 additions & 2 deletions cpp/src/strings/regex/regcomp.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -116,14 +116,19 @@ class reprog {

void optimize1();
void optimize2();
void check_for_errors();
#ifndef NDEBUG
void print(regex_flags const flags);
#endif

private:
std::vector<reinst> _insts;
std::vector<reclass> _classes;
int32_t _startinst_id;
std::vector<int32_t> _startinst_ids; // short-cut to speed-up ORs
int32_t _num_capturing_groups;
int32_t _num_capturing_groups{};

void check_for_errors(int32_t id, int32_t next_id);
};

} // namespace detail
Expand Down
9 changes: 9 additions & 0 deletions cpp/tests/strings/contains_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ TEST_F(StringsContainsTests, EmbeddedNullCharacter)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}

TEST_F(StringsContainsTests, Errors)
{
cudf::test::strings_column_wrapper input({"3", "33"});
auto strings_view = cudf::strings_column_view(input);

EXPECT_THROW(cudf::strings::contains_re(strings_view, "(3?)+"), cudf::logic_error);
EXPECT_THROW(cudf::strings::contains_re(strings_view, "3?+"), cudf::logic_error);
}

TEST_F(StringsContainsTests, CountTest)
{
std::vector<const char*> h_strings{
Expand Down

0 comments on commit b2d5874

Please sign in to comment.