diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 5151484fcca..46fcba58074 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -771,6 +771,23 @@ def test_regexp_extract_idx_0(): 'regexp_extract(a, "^([a-d]*)[0-9]*([a-d]*)\\z", 0)'), conf=_regexp_conf) +def test_regexp_whitespace(): + gen = mk_str_gen('\u001e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'rlike(a, "\\s{2}")', + 'rlike(a, "\\s{3}")', + 'rlike(a, "[abcd]+\\s+[0-9]+")', + 'rlike(a, "\\S{3}")', + 'rlike(a, "[abcd]+\\s+\\S{2,3}")', + 'regexp_extract(a, "([a-d]+)([0-9\\s]+)([a-d]+)", 2)', + 'regexp_extract(a, "([a-d]+)(\\S+)([0-9]+)", 2)', + 'regexp_extract(a, "([a-d]+)(\\S+)([0-9]+)", 3)', + 'regexp_replace(a, "(\\s+)", "@")', + 'regexp_replace(a, "(\\S+)", "#")', + ), + conf=_regexp_conf) + def test_rlike(): gen = mk_str_gen('[abcd]{1,3}') assert_gpu_and_cpu_are_equal_collect( diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 705faa446a8..9025d7120e5 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -584,9 +584,6 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'b' | 'B' => // see https://github.com/NVIDIA/spark-rapids/issues/4517 throw new RegexUnsupportedException("word boundaries are not supported") - case 's' | 'S' => - // see https://github.com/NVIDIA/spark-rapids/issues/4528 - throw new RegexUnsupportedException("whitespace classes are not supported") case 'A' if mode == RegexSplitMode => throw new RegexUnsupportedException("string anchor \\A is not supported in split mode") case 'Z' if mode == RegexSplitMode => @@ -604,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'Z' => // see https://github.com/NVIDIA/spark-rapids/issues/4532 throw new RegexUnsupportedException("string anchor \\Z is not supported") + case 's' | 'S' => + val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( + RegexChar(' '), RegexChar('\u000b')) + chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped) + RegexCharacterClass(negated = ch.isUpper, characters = chars) case _ => regex } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 2589bf5c897..4cf88e2ae12 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -584,9 +584,6 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'b' | 'B' => // see https://github.com/NVIDIA/spark-rapids/issues/4517 throw new RegexUnsupportedException("word boundaries are not supported") - case 's' | 'S' => - // see https://github.com/NVIDIA/spark-rapids/issues/4528 - throw new RegexUnsupportedException("whitespace classes are not supported") case 'A' if mode == RegexSplitMode => throw new RegexUnsupportedException("string anchor \\A is not supported in split mode") case 'Z' if mode == RegexSplitMode => @@ -604,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'Z' => // see https://github.com/NVIDIA/spark-rapids/issues/4532 throw new RegexUnsupportedException("string anchor \\Z is not supported") + case 's' | 'S' => + val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( + RegexChar(' '), RegexChar('\u000b')) + chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped) + RegexCharacterClass(negated = ch.isUpper, characters = chars) case _ => regex } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index f1e2c0ca3ff..a0c26a1636c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -221,6 +221,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } + test("whitespace boundaries - replace") { + assertCpuGpuMatchesRegexpReplace( + Seq("\\s", "\\S"), + Seq("\u001eTEST")) + } + test("match literal $ - find") { assertCpuGpuMatchesRegexpFind( Seq("\\$", "\\$[0-9]"), @@ -292,7 +298,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpFind(patterns, inputs) } - private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\nBsdwSDWzZ" + private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ" private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON