Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sperlingxx committed Apr 1, 2022
1 parent a690fce commit 49a29a3
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
10 changes: 5 additions & 5 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,13 +772,13 @@ def test_regexp_extract_idx_0():
conf=_regexp_conf)

def test_regexp_whitespace():
gen = mk_str_gen('[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f{1,3}')
gen = mk_str_gen('\u001e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'regexp(a, "\\s{2}")',
'regexp(a, "\\s{3}")',
'regexp(a, "[abcd]+\\s+[0-9]+")',
'regexp(a, "\\S{3}")',
'rlike(a, "\\s{2}")',
'rlike(a, "\\s{3}")',
'rlike(a, "[abcd]+\\s+[0-9]+")',
'rlike(a, "\\S{3}")',
'rlike(a, "[abcd]+\\s+\\S{2,3}")',
'regexp_extract(a, "([a-d]+)([0-9\\s]+)([a-d]+)", 2)',
'regexp_extract(a, "([a-d]+)(\\S+)([0-9]+)", 2)',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,6 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'b' | 'B' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4517
throw new RegexUnsupportedException("word boundaries are not supported")
case 's' | 'S' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4528
throw new RegexUnsupportedException("whitespace classes are not supported")
case 'A' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
case 'Z' if mode == RegexSplitMode =>
Expand All @@ -604,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'Z' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4532
throw new RegexUnsupportedException("string anchor \\Z is not supported")
case 's' | 'S' =>
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u000b'))
chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case _ =>
regex
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'Z' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4532
throw new RegexUnsupportedException("string anchor \\Z is not supported")
case 's' | 'S' =>
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u000b'))
chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case _ =>
regex
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
}

test("whitespace boundaries - replace") {
assertCpuGpuMatchesRegexpReplace(
Seq("\\s", "\\S"),
Seq("\u001eTEST"))
}

test("match literal $ - find") {
assertCpuGpuMatchesRegexpFind(
Seq("\\$", "\\$[0-9]"),
Expand Down Expand Up @@ -292,7 +298,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\nBsdwSDWzZ"
private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ"

private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON

Expand Down

0 comments on commit 49a29a3

Please sign in to comment.