Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regex parser improvements and bug fixes #4087

Merged
merged 2 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ class RegexParser(pattern: String) {
private var pos = 0

def parse(): RegexAST = {
val ast = parseInternal()
val ast = parseUntil(() => eof())
if (!eof()) {
throw new RegexUnsupportedException("failed to parse full regex")
}
ast
}

private def parseInternal(): RegexAST = {
val term = parseTerm(() => peek().contains('|'))
private def parseUntil(until: () => Boolean): RegexAST = {
val term = parseTerm(() => until() || peek().contains('|'))
if (!eof() && peek().contains('|')) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (not needed as part of this PR): I think `peek()` returns `None` at EOF, so the explicit `!eof()` check is redundant and this can be simplified to:

Suggested change
if (!eof() && peek().contains('|')) {
if (peek().contains('|')) {

The same simplification applies to a few other places where `eof()` and `peek()` are combined.

consumeExpected('|')
RegexChoice(term, parseInternal())
RegexChoice(term, parseUntil(until))
} else {
term
}
Expand All @@ -64,7 +64,7 @@ class RegexParser(pattern: String) {
private def parseTerm(until: () => Boolean): RegexAST = {
val sequence = RegexSequence(new ListBuffer())
while (!eof() && !until()) {
parseFactor() match {
parseFactor(until) match {
case RegexSequence(parts) =>
sequence.parts ++= parts
case other =>
Expand All @@ -89,9 +89,10 @@ class RegexParser(pattern: String) {
}
}

private def parseFactor(): RegexAST = {
private def parseFactor(until: () => Boolean): RegexAST = {
var base = parseBase()
while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?')
while (!eof() && !until()
&& (peek().exists(ch => ch == '*' || ch == '+' || ch == '?')
Comment on lines +94 to +95
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the previous nit, this should be simplifiable to something like:

Suggested change
while (!eof() && !until()
&& (peek().exists(ch => ch == '*' || ch == '+' || ch == '?')
while (!until() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?')

|| isValidQuantifierAhead())) {

val quantifier = if (peek().contains('{')) {
Expand All @@ -116,15 +117,26 @@ class RegexParser(pattern: String) {
case '\u0000' =>
throw new RegexUnsupportedException(
"cuDF does not support null characters in regular expressions", Some(pos))
case '*' | '+' | '?' =>
throw new RegexUnsupportedException(
"base expression cannot start with quantifier", Some(pos))
case other =>
RegexChar(other)
}
}

private def parseGroup(): RegexAST = {
val term = parseTerm(() => peek().contains(')'))
val captureGroup = if (pos + 1 < pattern.length
&& pattern.charAt(pos) == '?'
&& pattern.charAt(pos+1) == ':') {
pos += 2
false
} else {
true
}
val term = parseUntil(() => peek().contains(')'))
consumeExpected(')')
RegexGroup(term)
RegexGroup(captureGroup, term)
}

private def parseCharacterClass(): RegexCharacterClass = {
Expand All @@ -138,7 +150,9 @@ class RegexParser(pattern: String) {
case '[' =>
// treat as a literal character and add to the character class
characterClass.append(ch)
case ']' =>
case ']' if pos > start + 1 =>
// "[]" is not a valid character class
// "[]a]" is a valid character class containing the characters "]" and "a"
characterClassComplete = true
case '^' if pos == start + 1 =>
// Negates the character class, causing it to match a single character not listed in
Expand Down Expand Up @@ -523,12 +537,8 @@ class CudfRegexTranspiler {
RegexChoice(rewrite(l), rewrite(r))
}

case RegexGroup(term) => term match {
case RegexSequence(ListBuffer(RegexChar(ch))) if "?*+".contains(ch) =>
throw new RegexUnsupportedException(nothingToRepeat)
case _ =>
RegexGroup(rewrite(term))
}
case RegexGroup(capture, term) =>
RegexGroup(capture, rewrite(term))

case other =>
throw new RegexUnsupportedException(s"Unhandled expression in transpiler: $other")
Expand All @@ -551,9 +561,13 @@ sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST {
override def toRegexString: String = parts.map(_.toRegexString).mkString
}

sealed case class RegexGroup(term: RegexAST) extends RegexAST {
sealed case class RegexGroup(capture: Boolean, term: RegexAST) extends RegexAST {
override def children(): Seq[RegexAST] = Seq(term)
override def toRegexString: String = s"(${term.toRegexString})"
override def toRegexString: String = if (capture) {
s"(${term.toRegexString})"
} else {
s"(?:${term.toRegexString})"
}
}

sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class RegularExpressionParserSuite extends FunSuite {
test("group") {
assert(parse("(a)(b)") ===
RegexSequence(ListBuffer(
RegexGroup(RegexSequence(ListBuffer(RegexChar('a')))),
RegexGroup(RegexSequence(ListBuffer(RegexChar('b')))))))
RegexGroup(capture = true, RegexSequence(ListBuffer(RegexChar('a')))),
RegexGroup(capture = true, RegexSequence(ListBuffer(RegexChar('b')))))))
}

test("character class") {
Expand All @@ -64,6 +64,19 @@ class RegularExpressionParserSuite extends FunSuite {
RegexCharacterRange('A', 'Z'))))))
}

test("character classes containing ']'") {
// Checks how ']' is classified by position: directly after the opening '[' it is expected
// to be a literal member of the class, while a later ']' closes the class.
// "[]a]" is a valid character class containing ']' and 'a'
assert(parse("[]a]") ===
RegexSequence(ListBuffer(
RegexCharacterClass(negated = false,
ListBuffer(RegexChar(']'), RegexChar('a'))))))
// "[a]]" is a valid character class "[a]" followed by character ']'
// (the trailing ']' is expected as a plain RegexChar outside the class)
assert(parse("[a]]") ===
RegexSequence(ListBuffer(
RegexCharacterClass(negated = false,
ListBuffer(RegexChar('a'))), RegexChar(']'))))
}

test("hex digit") {
assert(parse(raw"\xFF") ===
RegexSequence(ListBuffer(RegexHexDigit("FF"))))
Expand All @@ -85,6 +98,24 @@ class RegularExpressionParserSuite extends FunSuite {
RegexSequence(ListBuffer(RegexOctalChar("47"), RegexChar('7'))))
}

test("group containing choice with repetition") {
// Inside a capturing group, '|' is expected to split the body into a RegexChoice of two
// sequences rather than appear as a literal character, and '+' is expected to apply only
// to the tab character that precedes it.
assert(parse("(\t+|a)") == RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer(
RegexRepetition(RegexChar('\t'),SimpleQuantifier('+')))),
RegexSequence(ListBuffer(RegexChar('a'))))))))
}

test("group containing quantifier") {
// "(?)" is a group whose body starts with a quantifier, so there is nothing to repeat;
// the parser is expected to reject it with RegexUnsupportedException.
val e = intercept[RegexUnsupportedException] {
parse("(?)")
}
assert(e.getMessage.startsWith("base expression cannot start with quantifier"))

// "(?:" opens a non-capturing group (capture = false), so the leading '?' is not a
// quantifier here; only the '?' after 'a' is expected to produce a RegexRepetition.
assert(parse("(?:a?)") === RegexSequence(ListBuffer(
RegexGroup(capture = false, RegexSequence(ListBuffer(
RegexRepetition(RegexChar('a'), SimpleQuantifier('?'))))))))
}

test("complex expression") {
val ast = parse(
"^" + // start of line
Expand All @@ -105,51 +136,50 @@ class RegularExpressionParserSuite extends FunSuite {
"$" // end of line
)
assert(ast ===
RegexSequence(ListBuffer(
RegexChar('^'),
RegexRepetition(
RegexCharacterClass(negated = false,
ListBuffer(RegexChar('+'), RegexEscaped('-'))), SimpleQuantifier('?')),
RegexGroup(RegexSequence(ListBuffer(
RegexGroup(RegexSequence(ListBuffer(
RegexGroup(RegexSequence(ListBuffer(
RegexGroup(RegexSequence(ListBuffer(
RegexSequence(ListBuffer(RegexChar('^'),
RegexRepetition(RegexCharacterClass(negated = false, ListBuffer(
RegexChar('+'), RegexEscaped('-'))), SimpleQuantifier('?')),
RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexRepetition(RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange('0', '9'))), SimpleQuantifier('+'))))))),
RegexChoice(RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexRepetition(
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange('0', '9'))), SimpleQuantifier('*')), RegexEscaped('.'),
RegexRepetition(
RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))),
RegexChar('|'),
RegexGroup(RegexSequence(ListBuffer(
RegexCharacterClass(negated = false, ListBuffer(RegexCharacterRange('0', '9'))),
SimpleQuantifier('+'))))))), RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexRepetition(
RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('*')),
RegexEscaped('.'),
RegexCharacterClass(negated = false, ListBuffer(RegexCharacterRange('0', '9'))),
SimpleQuantifier('+')), RegexEscaped('.'),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))),
RegexChar('|'),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously we were parsing `|` inside a group as a literal character rather than as a choice delimiter.

RegexGroup(RegexSequence(ListBuffer(RegexRepetition(
RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+')),
RegexEscaped('.'),
ListBuffer(RegexCharacterRange('0', '9'))),
SimpleQuantifier('*')))))))))),
RegexRepetition(
RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('*')))))))),
RegexRepetition(RegexGroup(RegexSequence(ListBuffer(
RegexCharacterClass(negated = false,
ListBuffer(RegexChar('e'),
RegexChar('E'))),
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexCharacterClass(negated = false, ListBuffer(RegexChar('e'), RegexChar('E'))),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexChar('+'), RegexEscaped('-'))),SimpleQuantifier('?')),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))),
SimpleQuantifier('?')),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexChar('f'), RegexChar('F'),
RegexChar('d'), RegexChar('D'))),SimpleQuantifier('?'))))),
RegexChar('|'), RegexChar('I'), RegexChar('n'), RegexChar('f'), RegexChar('|'),
RegexCharacterClass(negated = false, ListBuffer(RegexChar('n'), RegexChar('N'))),
RegexCharacterClass(negated = false, ListBuffer(RegexChar('a'), RegexChar('A'))),
RegexCharacterClass(negated = false, ListBuffer(RegexChar('n'), RegexChar('N')))))
),
ListBuffer(RegexCharacterRange('0', '9'))),
SimpleQuantifier('+'))))), SimpleQuantifier('?')),
RegexRepetition(RegexCharacterClass(negated = false, ListBuffer(
RegexChar('f'), RegexChar('F'), RegexChar('d'), RegexChar('D'))),
SimpleQuantifier('?'))))))),
RegexChoice(RegexSequence(ListBuffer(
RegexChar('I'), RegexChar('n'), RegexChar('f'))),
RegexSequence(ListBuffer(
RegexCharacterClass(negated = false,
ListBuffer(RegexChar('n'), RegexChar('N'))),
RegexCharacterClass(negated = false,
ListBuffer(RegexChar('a'), RegexChar('A'))),
RegexCharacterClass(negated = false,
ListBuffer(RegexChar('n'), RegexChar('N')))))))),
RegexChar('$'))))
}

Expand Down