From 49a29a340e687f294b3de6769e8e05fbf62bb46d Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 1 Apr 2022 18:14:21 +0800 Subject: [PATCH] update --- integration_tests/src/main/python/string_test.py | 10 +++++----- .../scala/com/nvidia/spark/rapids/RegexParser.scala | 8 +++++--- .../scala/com/nvidia/spark/rapids/RegexParser.scala | 5 +++++ .../rapids/RegularExpressionTranspilerSuite.scala | 8 +++++++- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index c0e883d97fd..46fcba58074 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -772,13 +772,13 @@ def test_regexp_extract_idx_0(): conf=_regexp_conf) def test_regexp_whitespace(): - gen = mk_str_gen('[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f{1,3}') + gen = mk_str_gen('\u001e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10}') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( - 'regexp(a, "\\s{2}")', - 'regexp(a, "\\s{3}")', - 'regexp(a, "[abcd]+\\s+[0-9]+")', - 'regexp(a, "\\S{3}")', + 'rlike(a, "\\s{2}")', + 'rlike(a, "\\s{3}")', + 'rlike(a, "[abcd]+\\s+[0-9]+")', + 'rlike(a, "\\S{3}")', 'rlike(a, "[abcd]+\\s+\\S{2,3}")', 'regexp_extract(a, "([a-d]+)([0-9\\s]+)([a-d]+)", 2)', 'regexp_extract(a, "([a-d]+)(\\S+)([0-9]+)", 2)', diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 705faa446a8..9025d7120e5 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -584,9 +584,6 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'b' | 'B' => // see https://github.com/NVIDIA/spark-rapids/issues/4517 throw new RegexUnsupportedException("word boundaries are not supported") - case 's' | 'S' => - // see https://github.com/NVIDIA/spark-rapids/issues/4528 - throw new RegexUnsupportedException("whitespace classes are not supported") case 'A' if mode == RegexSplitMode => throw new RegexUnsupportedException("string anchor \\A is not supported in split mode") case 'Z' if mode == RegexSplitMode => @@ -604,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'Z' => // see https://github.com/NVIDIA/spark-rapids/issues/4532 throw new RegexUnsupportedException("string anchor \\Z is not supported") + case 's' | 'S' => + val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( + RegexChar(' '), RegexChar('\u000b')) + chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped) + RegexCharacterClass(negated = ch.isUpper, characters = chars) case _ => regex } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index b5f07f42cad..4cf88e2ae12 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -601,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) { case 'Z' => // see https://github.com/NVIDIA/spark-rapids/issues/4532 throw new RegexUnsupportedException("string anchor \\Z is not supported") + case 's' | 'S' => + val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer( + RegexChar(' '), RegexChar('\u000b')) + chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped) + RegexCharacterClass(negated = ch.isUpper, characters = chars) case _ => regex } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index f1e2c0ca3ff..a0c26a1636c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -221,6 +221,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } } + test("whitespace boundaries - replace") { + assertCpuGpuMatchesRegexpReplace( + Seq("\\s", "\\S"), + Seq("\u001eTEST")) + } + test("match literal $ - find") { assertCpuGpuMatchesRegexpFind( Seq("\\$", "\\$[0-9]"), @@ -292,7 +298,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpFind(patterns, inputs) } - private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\nBsdwSDWzZ" + private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ" private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON