Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hex digits in character classes and escaped characters in character class ranges #5532

Merged
Merged
20 changes: 18 additions & 2 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,17 +872,33 @@ def test_regexp_extract_idx_0():
'regexp_extract(a, "^([a-d]*)[0-9]*([a-d]*)\\z", 0)'),
conf=_regexp_conf)

# Runs rlike / regexp_extract / regexp_replace with character-class patterns over
# generated strings and hands the query to assert_gpu_and_cpu_are_equal_collect.
# The patterns cover a plain class, a negated class, ranges whose endpoints are
# control or escaped characters, '-' used as a range endpoint, and a range whose
# endpoints are written as hex escapes.
def test_character_classes():
# generated input: 1-3 of [abcd], 1-3 digits, 1-3 of [abcd], then 0-2 whitespace chars
gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}[ \n\t\r]{0,2}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'rlike(a, "[abcd]")',
'rlike(a, "[^\n\r]")',
'rlike(a, "[\n-\\]")',
'rlike(a, "[+--]")',
'regexp_extract(a, "[123]", 0)',
'regexp_replace(a, "[\\\\x41-\\\\x5a]", "@")',
),
conf=_regexp_conf)

def test_regexp_hexadecimal_digits():
gen = mk_str_gen(
'[abcd]\\\\x00\\\\x7f\\\\x80\\\\xff\\\\x{10ffff}\\\\x{00eeee}[\\\\xa0-\\\\xb0][abcd]')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'rlike(a, "\\\\x7f")',
'rlike(a, "\\\\x80")',
'rlike(a, "[\\\\xa0-\\\\xf0]")',
'rlike(a, "\\\\x{00eeee}")',
'regexp_extract(a, "([a-d]+)\\\\xa0([a-d]+)", 1)',
'regexp_replace(a, "\\\\xff", "")',
'regexp_replace(a, "\\\\x{10ffff}", "")',
'regexp_extract(a, "([a-d]+)[\\\\xa0\nabcd]([a-d]+)", 1)',
'regexp_replace(a, "\\\\xff", "@")',
'regexp_replace(a, "[\\\\xa0-\\\\xb0]", "@")',
'regexp_replace(a, "\\\\x{10ffff}", "@")',
),
conf=_regexp_conf)

Expand Down
105 changes: 72 additions & 33 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,41 @@ class RegexParser(pattern: String) {
}

private def parseCharacterClass(): RegexCharacterClass = {
val supportedMetaCharacters = "\\^-]+"
anthony-chang marked this conversation as resolved.
Show resolved Hide resolved

/**
 * Parses the character that follows a backslash inside a character class and returns the
 * component it denotes. Callers consume the backslash before calling this.
 *
 *  - `x` followed by hex digits (read by `parseHexDigit`) yields the `RegexChar` with
 *    that code point
 *  - `n`, `r`, `t`, `f`, `a`, `b`, `e` yield the matching control character
 *  - a character listed in `supportedMetaCharacters` yields a `RegexEscaped`
 *
 * Throws `RegexUnsupportedException` for an octal escape, for a hex escape above U+FFFF,
 * for any other escaped character, and when the pattern ends right after the backslash.
 */
def getEscapedComponent(): RegexCharacterClassComponent = {
  peek() match {
    case Some('x') =>
      consumeExpected('x')
      val hexDigits = parseHexDigit.a
      val codePoint = Integer.parseInt(hexDigits, 16)
      if (!Character.isBmpCodePoint(codePoint)) {
        // RegexChar holds a single UTF-16 Char, so `toChar` would silently truncate a
        // supplementary code point and the class would match the wrong character.
        throw new RegexUnsupportedException(
          s"cuDF character classes do not support hex digit above U+FFFF: $hexDigits",
          Some(pos))
      }
      RegexChar(codePoint.toChar)
    case Some('0') => throw new RegexUnsupportedException(
      "cuDF does not support octal digits in character classes")
    case Some(ch) =>
      consumeExpected(ch) match {
        // List of character literals with an escape from here, under "Characters"
        // https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html
        case 'n' => RegexChar('\n')
        case 'r' => RegexChar('\r')
        case 't' => RegexChar('\t')
        case 'f' => RegexChar('\f')
        case 'a' => RegexChar('\u0007')
        case 'b' => RegexChar('\b')
        case 'e' => RegexChar('\u001b')
        case other =>
          if (supportedMetaCharacters.contains(other)) {
            // an escaped metacharacter ('\\', '^', '-', ']', '+')
            RegexEscaped(other)
          } else {
            throw new RegexUnsupportedException(
              "Unsupported escaped character in character class", Some(pos))
          }
      }
    case None =>
      throw new RegexUnsupportedException(
        "Unclosed character class", Some(pos))
  }
}

val start = pos
val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer())
// loop until the end of the character class or EOF
Expand All @@ -185,44 +220,43 @@ class RegexParser(pattern: String) {
// Negates the character class, causing it to match a single character not listed in
// the character class. Only valid immediately after the opening '['
characterClass.negated = true
case '\n' | '\r' | '\t' | '\b' | '\f' | '\u0007' =>
// treat as a literal character and add to the character class
characterClass.append(ch)
case '\\' =>
peek() match {
case None =>
throw new RegexUnsupportedException(
s"Unclosed character class", Some(pos))
case Some(ch) =>
// typically an escaped metacharacter ('\\', '^', '-', ']', '+')
// within the character class, but could be any escaped character
characterClass.appendEscaped(consumeExpected(ch))
}
case '\u0000' =>
throw new RegexUnsupportedException(
"cuDF does not support null characters in regular expressions", Some(pos))
case _ =>
// check for range
val start = ch
case ch =>
val nextChar: RegexCharacterClassComponent = ch match {
case '\\' =>
getEscapedComponent() match {
case RegexChar(ch) if supportedMetaCharacters.contains(ch) =>
// A hex representation of a meta character gets treated as an escaped
// char. Example: [\x5ea] is treated as [\^a], not just [^a]
RegexEscaped(ch)
case other => other
}
case ch =>
RegexChar(ch)
}
peek() match {
case Some('-') =>
consumeExpected('-')
peek() match {
case Some(']') =>
// '-' at end of class e.g. "[abc-]"
characterClass.append(ch)
characterClass.append(nextChar)
characterClass.append('-')
case Some('\\') =>
consumeExpected('\\')
characterClass.appendRange(nextChar, getEscapedComponent())
case Some(end) =>
skip()
characterClass.appendRange(start, end)
characterClass.appendRange(nextChar, RegexChar(end))
case _ =>
throw new RegexUnsupportedException(
"unexpected EOF while parsing character range",
Some(pos))
}
case _ =>
// treat as supported literal character
characterClass.append(ch)
characterClass.append(nextChar)
}
}
}
Expand Down Expand Up @@ -393,6 +427,9 @@ class RegexParser(pattern: String) {
val value = Integer.parseInt(hexDigit, 16)
if (value < Character.MIN_CODE_POINT || value > Character.MAX_CODE_POINT) {
throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit")
} else if (value == 0) {
throw new RegexUnsupportedException(s"cuDF does not support null characters " +
s"in regular expressions", Some(pos))
}

RegexHexDigit(hexDigit)
Expand Down Expand Up @@ -796,15 +833,16 @@ class CudfRegexTranspiler(mode: RegexMode) {
// cuDF is not compatible with Java for \d so we transpile to Java's definition
// of [0-9]
// https://github.com/rapidsai/cudf/issues/10894
RegexCharacterClass(negated = false, ListBuffer(RegexCharacterRange('0', '9')))
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange(RegexChar('0'), RegexChar('9'))))
case 'w' =>
// cuDF is not compatible with Java for \w so we transpile to Java's definition
// of `[a-zA-Z_0-9]`
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange('a', 'z'),
RegexCharacterRange('A', 'Z'),
RegexCharacterRange(RegexChar('a'), RegexChar('z')),
RegexCharacterRange(RegexChar('A'), RegexChar('Z')),
RegexChar('_'),
RegexCharacterRange('0', '9')))
RegexCharacterRange(RegexChar('0'), RegexChar('9'))))
case 'D' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4475
throw new RegexUnsupportedException("non-digit class \\D is not supported")
Expand Down Expand Up @@ -852,7 +890,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
RegexChar('\u202f'), RegexChar('\u205f'), RegexChar('\u3000')
)
chars += RegexEscaped('t')
chars += RegexCharacterRange('\u2000', '\u200a')
chars += RegexCharacterRange(RegexChar('\u2000'), RegexChar('\u200a'))
if (ch.isUpper) {
negateCharacterClass(chars)
} else {
Expand Down Expand Up @@ -901,11 +939,6 @@ class CudfRegexTranspiler(mode: RegexMode) {
// - "[\02] should match the character with code point 2"
throw new RegexUnsupportedException(
"cuDF does not support octal digits in character classes")
case RegexEscaped(ch) if ch == 'x' =>
// examples
// - "[\x02] should match the character with code point 2"
throw new RegexUnsupportedException(
"cuDF does not support hex digits in character classes")
case _ =>
}
val components: Seq[RegexCharacterClassComponent] = characters
Expand Down Expand Up @@ -1268,10 +1301,11 @@ sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{
override def toRegexString: String = s"\\$a"
}

sealed case class RegexCharacterRange(start: Char, end: Char)
sealed case class RegexCharacterRange(start: RegexCharacterClassComponent,
end: RegexCharacterClassComponent)
extends RegexCharacterClassComponent{
override def children(): Seq[RegexAST] = Seq.empty
override def toRegexString: String = s"$start-$end"
override def toRegexString: String = s"${start.toRegexString}-${end.toRegexString}"
}

sealed case class RegexCharacterClass(
Expand All @@ -1284,11 +1318,16 @@ sealed case class RegexCharacterClass(
characters += RegexChar(ch)
}

/** Adds an already-constructed component to `characters` unchanged. */
def append(component: RegexCharacterClassComponent): Unit = {
characters += component
}

/** Wraps `ch` in a `RegexEscaped` and adds it to `characters`. */
def appendEscaped(ch: Char): Unit = {
characters += RegexEscaped(ch)
}

def appendRange(start: Char, end: Char): Unit = {
def appendRange(start: RegexCharacterClassComponent,
end: RegexCharacterClassComponent): Unit = {
characters += RegexCharacterRange(start, end)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ class RegularExpressionParserSuite extends FunSuite {
RegexSequence(ListBuffer(
RegexCharacterClass(negated = false,
ListBuffer(
RegexCharacterRange('a', 'z'),
RegexCharacterRange(RegexChar('a'), RegexChar('z')),
RegexChar('+'),
RegexCharacterRange('A', 'Z'))))))
RegexCharacterRange(RegexChar('A'), RegexChar('Z')))))))
}

test("character class complex example") {
Expand Down Expand Up @@ -120,7 +120,7 @@ class RegularExpressionParserSuite extends FunSuite {
RegexSequence(ListBuffer(
RegexRepetition(
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange('A', 'Z'))),
RegexCharacterRange(RegexChar('A'), RegexChar('Z')))),
SimpleQuantifier('+')
)
))
Expand Down Expand Up @@ -224,29 +224,33 @@ class RegularExpressionParserSuite extends FunSuite {
RegexGroup(capture = true, RegexChoice(RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexRepetition(RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange('0', '9'))), SimpleQuantifier('+'))))))),
RegexCharacterRange(RegexChar('0'), RegexChar('9')))),
SimpleQuantifier('+'))))))),
RegexChoice(RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexRepetition(
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange('0', '9'))), SimpleQuantifier('*')), RegexEscaped('.'),
RegexCharacterRange(RegexChar('0'), RegexChar('9')))),
SimpleQuantifier('*')), RegexEscaped('.'),
RegexRepetition(
RegexCharacterClass(negated = false, ListBuffer(RegexCharacterRange('0', '9'))),
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange(RegexChar('0'), RegexChar('9')))),
SimpleQuantifier('+'))))))), RegexSequence(ListBuffer(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexRepetition(
RegexCharacterClass(negated = false, ListBuffer(RegexCharacterRange('0', '9'))),
RegexCharacterClass(negated = false, ListBuffer(
RegexCharacterRange(RegexChar('0'), RegexChar('9')))),
SimpleQuantifier('+')), RegexEscaped('.'),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),
ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')))),
SimpleQuantifier('*')))))))))),
RegexRepetition(
RegexGroup(capture = true, RegexSequence(ListBuffer(
RegexCharacterClass(negated = false, ListBuffer(RegexChar('e'), RegexChar('E'))),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexChar('+'), RegexEscaped('-'))),SimpleQuantifier('?')),
RegexRepetition(RegexCharacterClass(negated = false,
ListBuffer(RegexCharacterRange('0', '9'))),
ListBuffer(RegexCharacterRange(RegexChar('0'), RegexChar('9')))),
SimpleQuantifier('+'))))), SimpleQuantifier('?')),
RegexRepetition(RegexCharacterClass(negated = false, ListBuffer(
RegexChar('f'), RegexChar('F'), RegexChar('d'), RegexChar('D'))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
)
}

test("cuDF does not support hex digits in character classes") {
// see https://github.com/NVIDIA/spark-rapids/issues/4865
val patterns = Seq(raw"[\x02]", raw"[\x2c]", raw"[\x7f]")
patterns.foreach(pattern =>
assertUnsupported(pattern, RegexFindMode,
"cuDF does not support hex digits in character classes"
)
)
}

test("octal digits - find") {
val patterns = Seq(raw"\07", raw"\077", raw"\0177", raw"\01772", raw"\0200",
raw"\0376", raw"\0377", raw"\02002")
Expand All @@ -183,6 +173,22 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, Seq("", "\u0007", "a\u0007b",
"\u0007\u003f\u007f", "\u0080", "a\u00fe\u00ffb", "ab\ueeeecd"))
}

test("hex digit character classes") {
val patterns = Seq(raw"[\x02]", raw"[\x2c]", raw"[\x7f]", raw"[\x80]", raw"[\x01-\xff]",
raw"[a-\xff]", raw"[\x20-z]")
val inputs = Seq("", "\u007f", "a\u007fb", "\u007f\u003f\u007f", "\u0080", "a\u00fe\u00ffb",
"\u007f2", "abcd", "\u0000\u007f\u00ff\u0123\u0abc")
assertCpuGpuMatchesRegexpFind(patterns, inputs)
anthony-chang marked this conversation as resolved.
Show resolved Hide resolved
}

// Runs find and replace checks over inputs containing tabs, newlines and carriage
// returns, using character classes whose members and range endpoints are control or
// escaped characters.
test("compare CPU and GPU: character range with escaped characters") {
val inputs = Seq("", "abc", "\r\n", "a[b\t\n \rc]d", "[\r+\n-\t[]")
// raw"..." keeps the backslash sequences in the pattern text, so they reach the regex
// parser as escapes rather than as literal control characters
assertCpuGpuMatchesRegexpFind(Seq(raw"[\r\n\t]", raw"[\t-\r]", raw"[\n-\\]"),
inputs)
// the first two replace patterns are plain string literals, so the classes contain the
// literal control characters themselves
// NOTE(review): the last replace pattern begins with a doubled backslash before "u002d",
// so it does not look like a range between two unicode escapes - confirm this is intended
assertCpuGpuMatchesRegexpReplace(Seq("[\t-\r]", "[\b-\t123\n]", raw"[\\u002d-\u007a]"),
inputs)
}

test("string anchors - find") {
val patterns = Seq("\\Atest", "\\A+test", "\\A{1}test", "\\A{1,}test",
Expand Down Expand Up @@ -896,33 +902,33 @@ class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) {
private def characterClassComponent = {
val baseGenerators = Seq[() => RegexCharacterClassComponent](
() => char,
() => charRange)
() => charRange,
() => hexDigit,
() => escapedChar)
val generators = if (skipKnownIssues) {
baseGenerators
} else {
baseGenerators ++ Seq(
() => escapedChar, // https://github.com/NVIDIA/spark-rapids/issues/4505
() => hexDigit, // https://github.com/NVIDIA/spark-rapids/issues/4865
() => octalDigit) // https://github.com/NVIDIA/spark-rapids/issues/4862
}
generators(rr.nextInt(generators.length))()
}

private def charRange: RegexCharacterClassComponent = {
val baseGenerators = Seq[() => RegexCharacterClassComponent](
() => RegexCharacterRange('a', 'z'),
() => RegexCharacterRange('A', 'Z'),
() => RegexCharacterRange('z', 'a'),
() => RegexCharacterRange('Z', 'A'),
() => RegexCharacterRange('0', '9'),
() => RegexCharacterRange('9', '0')
() => RegexCharacterRange(RegexChar('a'), RegexChar('z')),
() => RegexCharacterRange(RegexChar('A'), RegexChar('Z')),
() => RegexCharacterRange(RegexChar('z'), RegexChar('a')),
() => RegexCharacterRange(RegexChar('Z'), RegexChar('A')),
() => RegexCharacterRange(RegexChar('0'), RegexChar('9')),
() => RegexCharacterRange(RegexChar('9'), RegexChar('0'))
)
val generators = if (skipKnownIssues) {
baseGenerators
} else {
// we do not support escaped characters in character ranges yet
// see https://github.com/NVIDIA/spark-rapids/issues/4505
baseGenerators ++ Seq(() => RegexCharacterRange(char.ch, char.ch))
baseGenerators ++ Seq(() => RegexCharacterRange(char, char))
}
generators(rr.nextInt(generators.length))()
}
Expand Down