Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check for regex instructions causing an infinite-loop #10095

Merged
merged 6 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 100 additions & 39 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ class regex_compiler {
m_prog.set_start_inst(andstack[andstack.size() - 1].id_first);
m_prog.optimize1();
m_prog.optimize2();
m_prog.check_for_errors();
m_prog.set_groups_count(cursubid);
}
};
Expand Down Expand Up @@ -926,90 +927,150 @@ void reprog::optimize2()
_startinst_ids.push_back(-1); // terminator mark
}

/**
* @brief Check a specific instruction for errors.
*
* Currently this is checking for an infinite-loop condition as documented in this issue:
* https://github.com/rapidsai/cudf/issues/10006
*
* Example instructions list created from pattern `(A?)+`
* ```
* 0: CHAR c='A', next=2
* 1: OR right=0, left=2, next=2
* 2: RBRA id=1, next=4
* 3: LBRA id=1, next=1
* 4: OR right=3, left=5, next=5
* 5: END
* ```
*
* Following the example above, the instruction at `id==1` (OR)
* is being checked. If the instruction path returns to `id==1`
* without including the `0==CHAR` or `5==END` as in this example,
* then this would cause the runtime to go into an infinite-loop.
*
* It appears this example pattern is not valid. But Python interprets
* its behavior similarly to pattern `(A*)`. Handling this in the same
* way does not look feasible with the current implementation.
*
* @throw cudf::logic_error if instruction logic error is found
*
* @param id Instruction to check if repeated.
* @param next_id Next instruction to process.
*/
void reprog::check_for_errors(int32_t id, int32_t next_id)
{
auto inst = inst_at(next_id);
while (inst.type == LBRA || inst.type == RBRA) {
next_id = inst.u2.next_id;
inst = inst_at(next_id);
}
if (inst.type == OR) {
CUDF_EXPECTS(next_id != id, "Unsupported regex pattern");
check_for_errors(id, inst.u2.left_id);
check_for_errors(id, inst.u1.right_id);
}
}

/**
* @brief Check regex instruction set for any errors.
*
* Currently, this checks for OR instructions that eventually point back to themselves with only
* intervening capture group instructions between causing an infinite-loop during runtime
* evaluation.
*/
void reprog::check_for_errors()
{
for (auto id = 0; id < insts_count(); ++id) {
auto const inst = inst_at(id);
if (inst.type == OR) {
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
check_for_errors(id, inst.u2.left_id);
check_for_errors(id, inst.u1.right_id);
}
}
}

#ifndef NDEBUG
void reprog::print(regex_flags const flags)
{
printf("Flags = 0x%08x\n", static_cast<uint32_t>(flags));
printf("Instructions:\n");
for (std::size_t i = 0; i < _insts.size(); i++) {
const reinst& inst = _insts[i];
printf("%zu :", i);
printf("%3zu: ", i);
switch (inst.type) {
default: printf("Unknown instruction: %d, nextid= %d", inst.type, inst.u2.next_id); break;
default: printf("Unknown instruction: %d, next=%d", inst.type, inst.u2.next_id); break;
case CHAR:
if (inst.u1.c <= 32 || inst.u1.c >= 127)
printf(
"CHAR, c = '0x%02x', nextid= %d", static_cast<unsigned>(inst.u1.c), inst.u2.next_id);
else
printf("CHAR, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id);
if (inst.u1.c <= 32 || inst.u1.c >= 127) {
printf(" CHAR c='0x%02x', next=%d", static_cast<unsigned>(inst.u1.c), inst.u2.next_id);
} else {
printf(" CHAR c='%c', next=%d", inst.u1.c, inst.u2.next_id);
}
break;
case RBRA: printf("RBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); break;
case LBRA: printf("LBRA, subid= %d, nextid= %d", inst.u1.subid, inst.u2.next_id); break;
case RBRA: printf(" RBRA id=%d, next=%d", inst.u1.subid, inst.u2.next_id); break;
case LBRA: printf(" LBRA id=%d, next=%d", inst.u1.subid, inst.u2.next_id); break;
case OR:
printf("OR, rightid=%d, leftid=%d, nextid=%d",
inst.u1.right_id,
inst.u2.left_id,
inst.u2.next_id);
printf(
" OR right=%d, left=%d, next=%d", inst.u1.right_id, inst.u2.left_id, inst.u2.next_id);
break;
case STAR: printf("STAR, nextid= %d", inst.u2.next_id); break;
case PLUS: printf("PLUS, nextid= %d", inst.u2.next_id); break;
case QUEST: printf("QUEST, nextid= %d", inst.u2.next_id); break;
case ANY: printf("ANY, nextid= %d", inst.u2.next_id); break;
case ANYNL: printf("ANYNL, nextid= %d", inst.u2.next_id); break;
case NOP: printf("NOP, nextid= %d", inst.u2.next_id); break;
case STAR: printf(" STAR next=%d", inst.u2.next_id); break;
case PLUS: printf(" PLUS next=%d", inst.u2.next_id); break;
case QUEST: printf(" QUEST next=%d", inst.u2.next_id); break;
case ANY: printf(" ANY next=%d", inst.u2.next_id); break;
case ANYNL: printf(" ANYNL next=%d", inst.u2.next_id); break;
case NOP: printf(" NOP next=%d", inst.u2.next_id); break;
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
case BOL: {
printf("BOL, c = ");
printf(" BOL c=");
if (inst.u1.c == '\n') {
printf("'\\n'");
} else {
printf("'%c'", inst.u1.c);
}
printf(", nextid= %d", inst.u2.next_id);
printf(", next=%d", inst.u2.next_id);
break;
}
case EOL: {
printf("EOL, c = ");
printf(" EOL c=");
if (inst.u1.c == '\n') {
printf("'\\n'");
} else {
printf("'%c'", inst.u1.c);
}
printf(", nextid= %d", inst.u2.next_id);
printf(", next=%d", inst.u2.next_id);
break;
}
case CCLASS: printf("CCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); break;
case NCCLASS:
printf("NCCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id);
break;
case BOW: printf("BOW, nextid= %d", inst.u2.next_id); break;
case NBOW: printf("NBOW, nextid= %d", inst.u2.next_id); break;
case END: printf("END"); break;
case CCLASS: printf(" CCLASS cls=%d , next=%d", inst.u1.cls_id, inst.u2.next_id); break;
case NCCLASS: printf("NCCLASS cls=%d, next=%d", inst.u1.cls_id, inst.u2.next_id); break;
case BOW: printf(" BOW next=%d", inst.u2.next_id); break;
case NBOW: printf(" NBOW next=%d", inst.u2.next_id); break;
case END: printf(" END"); break;
}
printf("\n");
}

printf("startinst_id=%d\n", _startinst_id);
if (_startinst_ids.size() > 0) {
printf("startinst_ids:");
for (size_t i = 0; i < _startinst_ids.size(); i++)
printf("startinst_ids: [");
for (size_t i = 0; i < _startinst_ids.size(); i++) {
printf(" %d", _startinst_ids[i]);
printf("\n");
}
printf("]\n");
}

int count = static_cast<int>(_classes.size());
printf("\nClasses %d\n", count);
for (int i = 0; i < count; i++) {
const reclass& cls = _classes[i];
int len = static_cast<int>(cls.literals.size());
auto const size = static_cast<int>(cls.literals.size());
printf("%2d: ", i);
for (int j = 0; j < len; j += 2) {
for (int j = 0; j < size; j += 2) {
char32_t c1 = cls.literals[j];
char32_t c2 = cls.literals[j + 1];
if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127)
if (c1 <= 32 || c1 >= 127 || c2 <= 32 || c2 >= 127) {
printf("0x%02x-0x%02x", static_cast<unsigned>(c1), static_cast<unsigned>(c2));
else
} else {
printf("%c-%c", static_cast<char>(c1), static_cast<char>(c2));
if ((j + 2) < len) printf(", ");
}
if ((j + 2) < size) { printf(", "); }
}
printf("\n");
if (cls.builtins) {
Expand All @@ -1024,7 +1085,7 @@ void reprog::print(regex_flags const flags)
}
printf("\n");
}
if (_num_capturing_groups) printf("Number of capturing groups: %d\n", _num_capturing_groups);
if (_num_capturing_groups) { printf("Number of capturing groups: %d\n", _num_capturing_groups); }
}
#endif

Expand Down
9 changes: 7 additions & 2 deletions cpp/src/strings/regex/regcomp.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -116,14 +116,19 @@ class reprog {

void optimize1();
void optimize2();
void check_for_errors();
#ifndef NDEBUG
void print(regex_flags const flags);
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
#endif

private:
std::vector<reinst> _insts;
std::vector<reclass> _classes;
int32_t _startinst_id;
std::vector<int32_t> _startinst_ids; // short-cut to speed-up ORs
int32_t _num_capturing_groups;
int32_t _num_capturing_groups{};

void check_for_errors(int32_t id, int32_t next_id);
};

} // namespace detail
Expand Down
9 changes: 9 additions & 0 deletions cpp/tests/strings/contains_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ TEST_F(StringsContainsTests, EmbeddedNullCharacter)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}

TEST_F(StringsContainsTests, Errors)
{
cudf::test::strings_column_wrapper input({"3", "33"});
auto strings_view = cudf::strings_column_view(input);

EXPECT_THROW(cudf::strings::contains_re(strings_view, "(3?)+"), cudf::logic_error);
EXPECT_THROW(cudf::strings::contains_re(strings_view, "3?+"), cudf::logic_error);
vyasr marked this conversation as resolved.
Show resolved Hide resolved
}

TEST_F(StringsContainsTests, CountTest)
{
std::vector<const char*> h_strings{
Expand Down