From 1cfc1730f81245623988f1fd8db548059068178f Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 7 Feb 2022 13:57:47 -0700
Subject: [PATCH 01/15] Draft implementation of string_split with regexp
 support

Signed-off-by: Andy Grove
---
 .../src/main/python/string_test.py            | 29 +++++++-
 .../spark/sql/rapids/stringFunctions.scala    | 17 +++--
 .../RegularExpressionTranspilerSuite.scala    | 69 +++++++++++++++++--
 3 files changed, 104 insertions(+), 11 deletions(-)

diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 296a1bad26c..ae4384f5fbc 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -27,13 +27,40 @@ def mk_str_gen(pattern):
 
 def test_split():
     data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
-    delim = '_'
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, "AB")',
             'split(a, "C")',
             'split(a, "_")'))
 
+def test_split_re_negative_limit():
+    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
+        .with_special_case('boo:and:foo')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, ":", -1)',
+            'split(a, "o", -2)'))
+
+def test_split_re_zero_limit():
+    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
+        .with_special_case('boo:and:foo')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, ":", 0)',
+            'split(a, "o", 0)'))
+
+def test_split_re_postive_limit():
+    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
+        .with_special_case('boo:and:foo')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, ":", 1)',
+            'split(a, ":", 2)',
+            'split(a, ":", 5)',
+            'split(a, "o", 1)',
+            'split(a, "o", 2)',
+            'split(a, "o", 5)'))
+
 @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
                                             (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
                                             (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index f9f6ce4e47a..475543cebea 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1296,9 +1296,6 @@ class GpuStringSplitMeta(
     } else {
       val str = regexp.get.value.asInstanceOf[UTF8String]
       if (str != null) {
-        if (RegexParser.isRegExpString(str.toString)) {
-          willNotWorkOnGpu("regular expressions are not supported yet")
-        }
         if (str.numChars() == 0) {
           willNotWorkOnGpu("An empty regex is not supported yet")
         }
@@ -1333,7 +1330,19 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
   override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
       limit: GpuScalar): ColumnVector = {
     val intLimit = limit.getValue.asInstanceOf[Int]
-    str.getBase.stringSplitRecord(regex.getBase, intLimit)
+    val pattern = regex.getValue.asInstanceOf[UTF8String].toString
+    if (RegexParser.isRegExpString(pattern)) {
+      val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(pattern)
+      str.getBase.stringSplitRecord(
+        cudfPattern,
+        intLimit,
+        true)
+    } else {
+      str.getBase.stringSplitRecord(
+        pattern,
+        intLimit,
+        false)
+    }
   }
 
   override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
index cef033d53f5..a4c8cbab593 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
@@ -371,8 +371,52 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true)
   }
 
+  test("string split") {
+    val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
+    val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
+    for (limit <- Seq(-2, -1, 0, 1, 2, 5)) {
+      doStringSplitTest(patterns, data, limit)
+    }
+  }
+
+  test("string split fuzz ") {
+    val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE),
+      replace = true)
+    for (limit <- Seq(-2, -1, 0, 1, 2, 5)) {
+      doStringSplitTest(patterns, data, limit)
+    }
+  }
+
+  def doStringSplitTest(patterns: Set[String], data: Seq[String], limit: Int) {
+    for (pattern <- patterns) {
+      val cpu = cpuSplit(pattern, data, limit)
+      val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(pattern)
+      val gpu = gpuSplit(cudfPattern, data, limit)
+      assert(cpu.length == gpu.length)
+      for (i <- cpu.indices) {
+        val cpuArray = cpu(i)
+        val gpuArray = gpu(i)
+        if (!cpuArray.sameElements(gpuArray)) {
+          fail(s"string_split pattern=${toReadableString(pattern)} " +
+            s"data=${toReadableString(data(i))} limit=$limit " +
+            s"\nCPU [${cpuArray.length}]: ${toReadableString(cpuArray.mkString(", "))} " +
+            s"\nGPU [${gpuArray.length}]: ${toReadableString(gpuArray.mkString(", "))}")
+        }
+      }
+    }
+  }
+
   private def doAstFuzzTest(validChars: Option[String], replace: Boolean) {
+    val (data, patterns) = generateDataAndPatterns(validChars, replace)
+    if (replace) {
+      assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
+    } else {
+      assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
+    }
+  }
+
+  private def generateDataAndPatterns(validChars: Option[String], replace: Boolean)
+      : (Seq[String], Set[String]) = {
     val r = new EnhancedRandom(new Random(seed = 0L),
       FuzzerOptions(validChars, maxStringLen = 12))
 
@@ -391,12 +435,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
         }
       }
     }
-
-    if (replace) {
-      assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
-    } else {
-      assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
-    }
+    (data, patterns.toSet)
   }
 
   private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = {
@@ -494,6 +533,24 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     input.map(s => p.matcher(s).replaceAll(REPLACE_STRING)).toArray
   }
 
+  private def cpuSplit(pattern: String, input: Seq[String], maxSplit: Int): Seq[Array[String]] = {
+    input.map(s => s.split(pattern, maxSplit))
+  }
+
+  private def gpuSplit(pattern: String, input: Seq[String], maxSplit: Int): Seq[Array[String]] = {
+    val isRegex = RegexParser.isRegExpString(pattern)
+    withResource(ColumnVector.fromStrings(input: _*)) { cv =>
+      withResource(cv.stringSplitRecord(pattern, maxSplit, isRegex)) { x =>
+        withResource(x.copyToHost()) { hcv =>
+          (0 until hcv.getRowCount.toInt).map(i => {
+            val list = hcv.getList(i)
+            list.toArray(new Array[String](list.size()))
+          })
+        }
+      }
+    }
+  }
+
   private def doTranspileTest(pattern: String, expected: String) {
     val transpiled: String = transpile(pattern, replace = false)
     assert(toReadableString(transpiled) === toReadableString(expected))
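A quick sketch of the user-visible behaviour this draft targets (not part of the
patch; it assumes an active SparkSession named `spark` with the RAPIDS plugin
enabled):

    // split() with a regular-expression delimiter should now stay on the GPU
    // instead of falling back to the CPU. Java regex semantics apply either way.
    val df = spark.sql("SELECT split('boo:and:foo', '[:]') AS parts")
    df.show(false)  // expected: [boo, and, foo]
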
From 392946136f4fdfb52daa305c2359f2b64f774f01 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 7 Feb 2022 14:01:04 -0700
Subject: [PATCH 02/15] Rename maxSplit to limit and add TODO comment

---
 .../org/apache/spark/sql/rapids/stringFunctions.scala     | 2 ++
 .../spark/rapids/RegularExpressionTranspilerSuite.scala   | 8 ++++----
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 475543cebea..efbd7eee37a 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1335,11 +1335,13 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
       val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(pattern)
       str.getBase.stringSplitRecord(
         cudfPattern,
+        // TODO this parameter has different meaning between Java and cuDF (limit vs maxSplit)
         intLimit,
         true)
     } else {
       str.getBase.stringSplitRecord(
         pattern,
+        // TODO this parameter has different meaning between Java and cuDF (limit vs maxSplit)
         intLimit,
         false)
     }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
index a4c8cbab593..ff5e934083c 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
@@ -533,14 +533,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     input.map(s => p.matcher(s).replaceAll(REPLACE_STRING)).toArray
   }
 
-  private def cpuSplit(pattern: String, input: Seq[String], maxSplit: Int): Seq[Array[String]] = {
-    input.map(s => s.split(pattern, maxSplit))
+  private def cpuSplit(pattern: String, input: Seq[String], limit: Int): Seq[Array[String]] = {
+    input.map(s => s.split(pattern, limit))
   }
 
-  private def gpuSplit(pattern: String, input: Seq[String], maxSplit: Int): Seq[Array[String]] = {
+  private def gpuSplit(pattern: String, input: Seq[String], limit: Int): Seq[Array[String]] = {
     val isRegex = RegexParser.isRegExpString(pattern)
     withResource(ColumnVector.fromStrings(input: _*)) { cv =>
-      withResource(cv.stringSplitRecord(pattern, maxSplit, isRegex)) { x =>
+      withResource(cv.stringSplitRecord(pattern, limit, isRegex)) { x =>
         withResource(x.copyToHost()) { hcv =>
           (0 until hcv.getRowCount.toInt).map(i => {
             val list = hcv.getList(i)

From 92366bdab0860e918f892174d64d2ae4ed133f77 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 7 Feb 2022 14:44:15 -0700
Subject: [PATCH 03/15] code cleanup and add separate tests for negative,
 zero, and positive limits

---
 .../spark/sql/rapids/stringFunctions.scala     | 21 ++++++-----------
 .../RegularExpressionTranspilerSuite.scala     | 20 +++++++++++++---
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index efbd7eee37a..7afaeb0511c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1331,20 +1331,17 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
       limit: GpuScalar): ColumnVector = {
     val intLimit = limit.getValue.asInstanceOf[Int]
     val pattern = regex.getValue.asInstanceOf[UTF8String].toString
-    if (RegexParser.isRegExpString(pattern)) {
-      val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(pattern)
-      str.getBase.stringSplitRecord(
-        cudfPattern,
-        // TODO this parameter has different meaning between Java and cuDF (limit vs maxSplit)
-        intLimit,
-        true)
+    val isRegExp = RegexParser.isRegExpString(pattern)
+    val cudfPattern = if (isRegExp) {
+      new CudfRegexTranspiler(replace = false).transpile(pattern)
     } else {
-      str.getBase.stringSplitRecord(
-        pattern,
-        // TODO this parameter has different meaning between Java and cuDF (limit vs maxSplit)
-        intLimit,
-        false)
+      pattern
     }
+    str.getBase.stringSplitRecord(
+      cudfPattern,
+      // TODO this parameter has different meaning between Java and cuDF (limit vs maxSplit)
+      intLimit,
+      isRegExp)
   }
 
   override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
index ff5e934083c..01f664153e2 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
@@ -371,17 +371,31 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true)
   }
 
-  test("string split") {
+  test("string split - negative limit") {
     val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
     val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
-    for (limit <- Seq(-2, -1, 0, 1, 2, 5)) {
+    for (limit <- Seq(Integer.MIN_VALUE, -2, -1)) {
+      doStringSplitTest(patterns, data, limit)
+    }
+  }
+
+  test("string split - zero limit") {
+    val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
+    val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
+    doStringSplitTest(patterns, data, 0)
+  }
+
+  test("string split - positive limit") {
+    val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
+    val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
+    for (limit <- Seq(1, 2, 5, Integer.MAX_VALUE)) {
       doStringSplitTest(patterns, data, limit)
     }
   }
 
   test("string split fuzz ") {
     val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE),
-      replace = true)
+      replace = false)
     for (limit <- Seq(-2, -1, 0, 1, 2, 5)) {
       doStringSplitTest(patterns, data, limit)
     }
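A note on the TODO comments introduced above: Java's `limit` and a cuDF-style
`maxSplit` count different things. The sketch below (not from the patch) shows
the Java/Spark semantics that the CPU side of these tests relies on; whether
the cuDF parameter counts output tokens or the number of splits performed is
exactly the open question the TODO flags.

    // java.lang.String.split semantics (what Spark's CPU path follows):
    "boo:and:foo".split(":", 2)   // Array(boo, and:foo)  -- at most 2 tokens
    "boo:and:foo".split(":", 5)   // Array(boo, and, foo) -- limit is an upper bound
    "a:b:".split(":", 0)          // Array(a, b)          -- 0 drops trailing empties
    "a:b:".split(":", -1)         // Array(a, b, "")      -- negative keeps them
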
From 51d870d4874e8a46128b6cd2ef978236afd58a2d Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 7 Feb 2022 15:27:50 -0700
Subject: [PATCH 04/15] fall back to CPU for limit of 0 or 1

---
 .../src/main/python/string_test.py             | 30 ++++++++++++++-----
 .../spark/sql/rapids/stringFunctions.scala     |  9 ++++--
 .../RegularExpressionTranspilerSuite.scala     | 16 ++++------
 3 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index ae4384f5fbc..ee34632f7bb 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -14,7 +14,9 @@
 
 import pytest
 
-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, \
+    assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error, \
+    assert_cpu_and_gpu_are_equal_collect_with_capture
 from conftest import is_databricks_runtime
 from data_gen import *
 from marks import *
@@ -41,23 +43,37 @@ def test_split_re_negative_limit():
             'split(a, ":", -1)',
             'split(a, "o", -2)'))
 
-def test_split_re_zero_limit():
+@allow_non_gpu('ProjectExec', 'StringSplit')
+def test_split_re_zero_limit_fallback():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
-    assert_gpu_and_cpu_are_equal_collect(
+
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, ":", 0)',
-            'split(a, "o", 0)'))
+            'split(a, "o", 0)'),
+        exist_classes= "ProjectExec",
+        non_exist_classes= "GpuProjectExec")
 
-def test_split_re_postive_limit():
+@allow_non_gpu('ProjectExec', 'StringSplit')
+def test_split_re_one_limit_fallback():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
+
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, ":", 1)',
+            'split(a, "o", 1)'),
+        exist_classes= "ProjectExec",
+        non_exist_classes= "GpuProjectExec")
+
+def test_split_re_positive_limit():
+    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
+        .with_special_case('boo:and:foo')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
-            'split(a, ":", 1)',
             'split(a, ":", 2)',
             'split(a, ":", 5)',
-            'split(a, "o", 1)',
             'split(a, "o", 2)',
             'split(a, "o", 5)'))
 
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 7afaeb0511c..76dc858cc05 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1303,8 +1303,13 @@ class GpuStringSplitMeta(
         willNotWorkOnGpu("null regex is not supported yet")
       }
     }
-    if (!isLit(expr.limit)) {
-      willNotWorkOnGpu("only literal limit is supported")
+    extractLit(expr.limit) match {
+      case Some(Literal(n: Int, _)) =>
+        if (n == 0 || n == 1) {
+          willNotWorkOnGpu("limit of 0 or 1 is not supported")
+        }
+      case _ =>
+        willNotWorkOnGpu("only literal limit is supported")
     }
   }
   override def convertToGpu(
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
index 01f664153e2..a7368ea80ae 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
@@ -371,7 +371,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true)
   }
 
-  test("string split - negative limit") {
+  test("string split - limit < 0") {
     val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
     val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
     for (limit <- Seq(Integer.MIN_VALUE, -2, -1)) {
@@ -379,24 +379,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     }
   }
 
-  test("string split - zero limit") {
+  test("string split - limit > 1") {
     val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
     val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
-    doStringSplitTest(patterns, data, 0)
-  }
-
-  test("string split - positive limit") {
-    val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
-    val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
-    for (limit <- Seq(1, 2, 5, Integer.MAX_VALUE)) {
+    for (limit <- Seq(2, 5, Integer.MAX_VALUE)) {
       doStringSplitTest(patterns, data, limit)
     }
   }
 
-  test("string split fuzz ") {
+  test("string split fuzz") {
     val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE),
       replace = false)
-    for (limit <- Seq(-2, -1, 0, 1, 2, 5)) {
+    for (limit <- Seq(-2, -1, 2, 5)) {
       doStringSplitTest(patterns, data, limit)
     }
   }
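Why limits of 0 and 1 are special-cased back to the CPU in the patch above:
under Java semantics a limit of 1 means "no split at all" and a limit of 0
means "split fully, then drop trailing empty strings", and neither maps
cleanly onto a maxSplit-style parameter. A sketch of the two edge cases (not
from the patch):

    "boo:and:foo".split(":", 1)   // Array(boo:and:foo)  -- input comes back whole
    "boo:and:foo:".split(":", 0)  // Array(boo, and, foo) -- trailing "" removed
    "boo:and:foo:".split(":", 2)  // Array(boo, and:foo:) -- for comparison
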
From 9c09603870789943d873bc4a215df0680158a208 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 7 Feb 2022 16:13:38 -0700
Subject: [PATCH 05/15] fall back to CPU for split on regex containing string
 or line anchors

---
 docs/compatibility.md                          |  1 +
 .../shims/v2/GpuRegExpReplaceMeta.scala        |  2 +-
 .../com/nvidia/spark/rapids/RegexParser.scala  | 19 +++-
 .../spark/sql/rapids/stringFunctions.scala     |  6 +-
 .../RegularExpressionTranspilerSuite.scala     | 92 +++++++++----------
 5 files changed, 67 insertions(+), 53 deletions(-)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index 8b1dbec1786..7d1419836ba 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -518,6 +518,7 @@ Here are some examples of regular expression patterns that are not supported on
 - Line anchor `$`
 - String anchor `\Z`
 - String anchor `\z` is not supported by `regexp_replace`
+- Line and string anchors are not supported by `string_split`
 - Non-digit character class `\D`
 - Non-word character class `\W`
 - Word and non-word boundaries, `\b` and `\B`
diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala
index 2460e3d2e5f..8e4d60a536f 100644
--- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala
+++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala
@@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
       // use GpuStringReplace
     } else {
       try {
-        pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+        pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
       } catch {
         case e: RegexUnsupportedException =>
           willNotWorkOnGpu(e.getMessage)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
index 002c8b3f04b..c9d6909798e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
@@ -428,6 +428,11 @@ object RegexParser {
   }
 }
 
+sealed trait RegexMode
+object RegexFindMode extends RegexMode
+object RegexReplaceMode extends RegexMode
+object RegexSplitMode extends RegexMode
+
 /**
  * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception
  * if this is not possible.
@@ -435,7 +440,7 @@ object RegexParser {
  * @param replace True if performing a replacement (regexp_replace), false
  *                if matching only (rlike)
  */
-class CudfRegexTranspiler(replace: Boolean) {
+class CudfRegexTranspiler(mode: RegexMode) {
 
   // cuDF throws a "nothing to repeat" exception for many of the edge cases that are
   // rejected by the transpiler
@@ -467,6 +472,8 @@ class CudfRegexTranspiler(replace: Boolean) {
       case '$' =>
         // see https://github.com/NVIDIA/spark-rapids/issues/4533
         throw new RegexUnsupportedException("line anchor $ is not supported")
+      case '^' if mode == RegexSplitMode =>
+        throw new RegexUnsupportedException("line anchor ^ is not supported in split mode")
       case _ =>
         regex
     }
@@ -494,8 +501,14 @@ class CudfRegexTranspiler(replace: Boolean) {
       case 's' | 'S' =>
         // see https://github.com/NVIDIA/spark-rapids/issues/4528
         throw new RegexUnsupportedException("whitespace classes are not supported")
+      case 'A' if mode == RegexSplitMode =>
+        throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
+      case 'Z' if mode == RegexSplitMode =>
+        throw new RegexUnsupportedException("string anchor \\Z is not supported in split mode")
+      case 'z' if mode == RegexSplitMode =>
+        throw new RegexUnsupportedException("string anchor \\z is not supported in split mode")
       case 'z' =>
-        if (replace) {
+        if (mode == RegexReplaceMode) {
           // see https://github.com/NVIDIA/spark-rapids/issues/4425
           throw new RegexUnsupportedException(
             "string anchor \\z is not supported in replace mode")
@@ -590,7 +603,7 @@ class CudfRegexTranspiler(replace: Boolean) {
         RegexSequence(parts.map(rewrite))
       case RegexRepetition(base, quantifier) => (base, quantifier) match {
-        case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) =>
+        case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) =>
           // example: pattern " ?", input "] b[", replace with "X":
           // java: X]XXbX[X
           // cuDF: XXXX] b[
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 76dc858cc05..18e2cb9aca6 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -799,7 +799,7 @@ class GpuRLikeMeta(
       case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
         try {
           // verify that we support this regex and can transpile it to cuDF format
-          pattern = Some(new CudfRegexTranspiler(replace = false).transpile(str.toString))
+          pattern = Some(new CudfRegexTranspiler(RegexFindMode).transpile(str.toString))
         } catch {
           case e: RegexUnsupportedException =>
             willNotWorkOnGpu(e.getMessage)
@@ -946,7 +946,7 @@ class GpuRegExpExtractMeta(
           try {
             val javaRegexpPattern = str.toString
             // verify that we support this regex and can transpile it to cuDF format
-            val cudfRegexPattern = new CudfRegexTranspiler(replace = false)
+            val cudfRegexPattern = new CudfRegexTranspiler(RegexFindMode)
               .transpile(javaRegexpPattern)
             pattern = Some(cudfRegexPattern)
             numGroups = countGroups(new RegexParser(javaRegexpPattern).parse())
@@ -1338,7 +1338,7 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
     val pattern = regex.getValue.asInstanceOf[UTF8String].toString
     val isRegExp = RegexParser.isRegExpString(pattern)
     val cudfPattern = if (isRegExp) {
-      new CudfRegexTranspiler(replace = false).transpile(pattern)
+      new CudfRegexTranspiler(RegexSplitMode).transpile(pattern)
     } else {
       pattern
     }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
index a7368ea80ae..dae7dbed49f 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
@@ -49,9 +49,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     for (pattern <- cudfInvalidPatterns) {
       // check that this is valid in Java
       Pattern.compile(pattern)
-      Seq(true, false).foreach { replace =>
+      Seq(RegexFindMode, RegexReplaceMode).foreach { mode =>
         try {
-          if (replace) {
+          if (mode == RegexReplaceMode) {
             gpuReplace(pattern, inputs)
           } else {
             gpuContains(pattern, inputs)
@@ -61,9 +61,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
           case e: CudfException =>
             // expected, now make sure that the transpiler can detect this
             try {
-              transpile(pattern, replace)
+              transpile(pattern, mode)
               fail(
-                s"transpiler failed to detect invalid cuDF pattern (replace=$replace): $pattern", e)
+                s"transpiler failed to detect invalid cuDF pattern (mode=$mode): $pattern", e)
             } catch {
               case _: RegexUnsupportedException =>
                 // expected
@@ -76,7 +76,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("cuDF does not support choice with nothing to repeat") {
     val patterns = Seq("b+|^\t")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false, "nothing to repeat")
+      assertUnsupported(pattern, RegexFindMode, "nothing to repeat")
     )
   }
 
@@ -100,14 +100,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("cuDF does not support possessive quantifier") {
     val patterns = Seq("a*+", "a|(a?|a*+)")
    patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false, "nothing to repeat")
+      assertUnsupported(pattern, RegexFindMode, "nothing to repeat")
     )
   }
 
   test("cuDF does not support empty sequence") {
     val patterns = Seq("", "a|", "()")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false, "empty sequence not supported")
+      assertUnsupported(pattern, RegexFindMode, "empty sequence not supported")
     )
   }
 
@@ -115,14 +115,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     // note that we could choose to transpile and escape the '{' and '}' characters
     val patterns = Seq("{1,2}", "{1,}", "{1}", "{2,1}")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false, "nothing to repeat")
+      assertUnsupported(pattern, RegexFindMode, "nothing to repeat")
     )
   }
 
   test("cuDF does not support OR at BOL / EOL") {
     val patterns = Seq("$|a", "^|a")
     patterns.foreach(pattern => {
-      assertUnsupported(pattern, replace = false,
+      assertUnsupported(pattern, RegexFindMode,
         "sequences that only contain '^' or '$' are not supported")
     })
   }
 
@@ -130,7 +130,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("cuDF does not support null in pattern") {
     val patterns = Seq("\u0000", "a\u0000b", "a(\u0000)b", "a[a-b][\u0000]")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false,
+      assertUnsupported(pattern, RegexFindMode,
         "cuDF does not support null characters in regular expressions"))
   }
 
@@ -138,7 +138,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
    // see https://github.com/NVIDIA/spark-rapids/issues/4486
     val patterns = Seq(raw"\xA9", raw"\x00A9", raw"\x10FFFF")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false,
+      assertUnsupported(pattern, RegexFindMode,
        "cuDF does not support hex digits consistently with Spark"))
   }
 
@@ -146,7 +146,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     // see https://github.com/NVIDIA/spark-rapids/issues/4288
     val patterns = Seq(raw"\07", raw"\077", raw"\0377")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false,
+      assertUnsupported(pattern, RegexFindMode,
         "cuDF does not support octal digits consistently with Spark"))
   }
 
@@ -157,8 +157,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   }
 
   test("string anchor \\Z fall back to CPU") {
-    for (replace <- Seq(true, false)) {
-      assertUnsupported("\\Z", replace, "string anchor \\Z is not supported")
+    for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
+      assertUnsupported("\\Z", mode, "string anchor \\Z is not supported")
     }
   }
 
@@ -169,8 +169,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   }
 
   test("line anchor $ fall back to CPU") {
-    for (replace <- Seq(true, false)) {
-      assertUnsupported("a$b", replace, "line anchor $ is not supported")
+    for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
+      assertUnsupported("a$b", mode, "line anchor $ is not supported")
     }
   }
 
@@ -205,7 +205,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("transpile character class unescaped range symbol") {
     val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]")
     val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", "a(?:[\r\n]|[^\\-])")
-    val transpiler = new CudfRegexTranspiler(replace=false)
+    val transpiler = new CudfRegexTranspiler(RegexFindMode)
     val transpiled = patterns.map(transpiler.transpile)
     assert(transpiled === expected)
   }
@@ -259,15 +259,15 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("fall back to CPU for \\D") {
     // see https://github.com/NVIDIA/spark-rapids/issues/4475
-    for (replace <- Seq(true, false)) {
-      assertUnsupported("\\D", replace, "non-digit class \\D is not supported")
+    for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
+      assertUnsupported("\\D", mode, "non-digit class \\D is not supported")
     }
   }
 
   test("fall back to CPU for \\W") {
     // see https://github.com/NVIDIA/spark-rapids/issues/4475
-    for (replace <- Seq(true, false)) {
-      assertUnsupported("\\W", replace, "non-word class \\W is not supported")
+    for (mode <- Seq(RegexFindMode, RegexReplaceMode)) {
+      assertUnsupported("\\W", mode, "non-word class \\W is not supported")
     }
   }
 
@@ -283,7 +283,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     // testing with this limited set of characters finds issues much
     // faster than using the full ASCII set
     // CR and LF has been excluded due to known issues
-    doFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), replace = false)
+    doFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), RegexFindMode)
   }
 
   test("compare CPU and GPU: regexp replace simple regular expressions") {
@@ -301,7 +301,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("cuDF does not support some uses of line anchors in regexp_replace") {
     Seq("^$", "^", "$", "(^)($)", "(((^^^)))$", "^*", "$*", "^+", "$+", "^|$", "^^|$$").foreach(
       pattern =>
-        assertUnsupported(pattern, replace = true,
+        assertUnsupported(pattern, RegexReplaceMode,
          "sequences that only contain '^' or '$' are not supported")
     )
   }
@@ -317,27 +317,27 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     // testing with this limited set of characters finds issues much
     // faster than using the full ASCII set
     // LF has been excluded due to known issues
-    doFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true)
+    doFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), RegexReplaceMode)
   }
 
   test("compare CPU and GPU: regexp find fuzz test printable ASCII chars plus CR, LF, and TAB") {
     // CR and LF has been excluded due to known issues
-    doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\n\t"), replace = false)
+    doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\n\t"), RegexFindMode)
   }
 
   test("compare CPU and GPU: fuzz test ASCII chars") {
     // LF has been excluded due to known issues
     val chars = (0x00 to 0x7F)
       .map(_.toChar)
-    doFuzzTest(Some(chars.mkString), replace = true)
+    doFuzzTest(Some(chars.mkString), RegexReplaceMode)
   }
 
   test("compare CPU and GPU: regexp find fuzz test all chars") {
     // this test cannot be enabled until we support CR and LF
-    doFuzzTest(None, replace = false)
+    doFuzzTest(None, RegexFindMode)
   }
 
-  private def doFuzzTest(validChars: Option[String], replace: Boolean) {
+  private def doFuzzTest(validChars: Option[String], mode: RegexMode) {
 
     val r = new EnhancedRandom(new Random(seed = 0L),
       options = FuzzerOptions(validChars, maxStringLen = 12))
@@ -350,13 +350,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     while (patterns.size < 5000) {
       val pattern = r.nextString()
       if (!patterns.contains(pattern)) {
-        if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, replace)).isSuccess) {
+        if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, mode)).isSuccess) {
          patterns += pattern
        }
       }
     }
 
-    if (replace) {
+    if (mode == RegexReplaceMode) {
       assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
     } else {
       assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
@@ -364,11 +364,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   }
 
   test("AST fuzz test - regexp_find") {
-    doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), replace = false)
+    doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), RegexFindMode)
   }
 
   test("AST fuzz test - regexp_replace") {
-    doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true)
+    doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), RegexReplaceMode)
   }
 
   test("string split - limit < 0") {
@@ -389,7 +389,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
 
   test("string split fuzz") {
     val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE),
-      replace = false)
+      RegexSplitMode)
     for (limit <- Seq(-2, -1, 2, 5)) {
       doStringSplitTest(patterns, data, limit)
     }
@@ -398,7 +398,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   def doStringSplitTest(patterns: Set[String], data: Seq[String], limit: Int) {
     for (pattern <- patterns) {
       val cpu = cpuSplit(pattern, data, limit)
-      val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(pattern)
+      val cudfPattern = new CudfRegexTranspiler(RegexSplitMode).transpile(pattern)
       val gpu = gpuSplit(cudfPattern, data, limit)
       assert(cpu.length == gpu.length)
       for (i <- cpu.indices) {
@@ -414,16 +414,16 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     }
   }
 
-  private def doAstFuzzTest(validChars: Option[String], replace: Boolean) {
-    val (data, patterns) = generateDataAndPatterns(validChars, replace)
-    if (replace) {
+  private def doAstFuzzTest(validChars: Option[String], mode: RegexMode) {
+    val (data, patterns) = generateDataAndPatterns(validChars, mode)
+    if (mode == RegexReplaceMode) {
       assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data)
     } else {
       assertCpuGpuMatchesRegexpFind(patterns.toSeq, data)
     }
   }
 
-  private def generateDataAndPatterns(validChars: Option[String], replace: Boolean)
+  private def generateDataAndPatterns(validChars: Option[String], mode: RegexMode)
       : (Seq[String], Set[String]) = {
     val r = new EnhancedRandom(new Random(seed = 0L),
       FuzzerOptions(validChars, maxStringLen = 12))
@@ -438,7 +438,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
     while (patterns.size < 5000) {
       val pattern = fuzzer.generate(0).toRegexString
       if (!patterns.contains(pattern)) {
-        if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, replace)).isSuccess) {
+        if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, mode)).isSuccess) {
           patterns += pattern
         }
       }
@@ -449,7 +449,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = {
     for (javaPattern <- javaPatterns) {
       val cpu = cpuContains(javaPattern, input)
-      val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(javaPattern)
+      val cudfPattern = new CudfRegexTranspiler(RegexFindMode).transpile(javaPattern)
       val gpu = try {
         gpuContains(cudfPattern, input)
       } catch {
@@ -472,7 +472,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
       input: Seq[String]) = {
     for (javaPattern <- javaPatterns) {
       val cpu = cpuReplace(javaPattern, input)
-      val cudfPattern = new CudfRegexTranspiler(replace = true).transpile(javaPattern)
+      val cudfPattern = new CudfRegexTranspiler(RegexReplaceMode).transpile(javaPattern)
       val gpu = try {
         gpuReplace(cudfPattern, input)
       } catch {
@@ -560,17 +560,17 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   }
 
   private def doTranspileTest(pattern: String, expected: String) {
-    val transpiled: String = transpile(pattern, replace = false)
+    val transpiled: String = transpile(pattern, RegexFindMode)
     assert(toReadableString(transpiled) === toReadableString(expected))
   }
 
-  private def transpile(pattern: String, replace: Boolean): String = {
-    new CudfRegexTranspiler(replace).transpile(pattern)
+  private def transpile(pattern: String, mode: RegexMode): String = {
+    new CudfRegexTranspiler(mode).transpile(pattern)
   }
 
-  private def assertUnsupported(pattern: String, replace: Boolean, message: String): Unit = {
+  private def assertUnsupported(pattern: String, mode: RegexMode, message: String): Unit = {
     val e = intercept[RegexUnsupportedException] {
-      transpile(pattern, replace)
+      transpile(pattern, mode)
     }
     assert(e.getMessage.startsWith(message), pattern)
  }
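A sketch of how the new RegexSplitMode surfaces to callers (it mirrors the
suite's assertUnsupported helper; the exception type and message text are taken
from the patch above, the surrounding test scaffolding is assumed):

    val e = intercept[RegexUnsupportedException] {
      new CudfRegexTranspiler(RegexSplitMode).transpile("a\\Zb")
    }
    assert(e.getMessage.startsWith("string anchor \\Z is not supported in split mode"))
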
From 55567177419616b58bd2693362f2075fa56c9806 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 8 Feb 2022 08:52:09 -0700
Subject: [PATCH 06/15] move some logic from GpuStringSplit to
 GpuStringSplitMeta

---
 .../spark/sql/rapids/stringFunctions.scala | 36 ++++++++++---------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 18e2cb9aca6..c3be6928140 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1289,6 +1289,9 @@ class GpuStringSplitMeta(
   extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
   import GpuOverrides._
 
+  private var pattern: String = _
+  private var isRegExp = false
+
   override def tagExprForGpu(): Unit = {
     val regexp = extractLit(expr.regex)
     if (regexp.isEmpty) {
@@ -1299,6 +1302,17 @@ class GpuStringSplitMeta(
         if (str.numChars() == 0) {
           willNotWorkOnGpu("An empty regex is not supported yet")
         }
+        isRegExp = RegexParser.isRegExpString(str.toString)
+        if (isRegExp) {
+          try {
+            pattern = new CudfRegexTranspiler(RegexSplitMode).transpile(str.toString)
+          } catch {
+            case e: RegexUnsupportedException =>
+              willNotWorkOnGpu(e.getMessage)
+          }
+        } else {
+          pattern = str.toString
+        }
       } else {
         willNotWorkOnGpu("null regex is not supported yet")
       }
@@ -1316,11 +1330,12 @@ class GpuStringSplitMeta(
       str: Expression,
       regexp: Expression,
       limit: Expression): GpuExpression =
-    GpuStringSplit(str, regexp, limit)
+    GpuStringSplit(str, regexp, limit, isRegExp, pattern)
 }
 
-case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
-  extends GpuTernaryExpression with ImplicitCastInputTypes {
+case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
+    isRegExp: Boolean, pattern: String)
+  extends GpuTernaryExpression with ImplicitCastInputTypes {
   override def dataType: DataType = ArrayType(StringType)
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
   override def first: Expression = str
   override def second: Expression = regex
   override def third: Expression = limit
 
-  def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType))
-
   override def prettyName: String = "split"
 
   override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
       limit: GpuScalar): ColumnVector = {
     val intLimit = limit.getValue.asInstanceOf[Int]
-    val pattern = regex.getValue.asInstanceOf[UTF8String].toString
-    val isRegExp = RegexParser.isRegExpString(pattern)
-    val cudfPattern = if (isRegExp) {
-      new CudfRegexTranspiler(RegexSplitMode).transpile(pattern)
-    } else {
-      pattern
-    }
-    str.getBase.stringSplitRecord(
-      cudfPattern,
-      // TODO this parameter has different meaning between Java and cuDF (limit vs maxSplit)
-      intLimit,
-      isRegExp)
+    str.getBase.stringSplitRecord(pattern, intLimit, isRegExp)
   }
 
   override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar,
From 14b873cfd7e302985b9649e51f4120cfe8d4a4fb Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 8 Feb 2022 09:22:15 -0700
Subject: [PATCH 07/15] Additional tests

---
 .../src/main/python/string_test.py         | 50 +++++++++++++++++--
 .../spark/sql/rapids/stringFunctions.scala |  1 +
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index ee34632f7bb..412b072f7a3 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -27,7 +27,7 @@ def mk_str_gen(pattern):
     return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
 
-def test_split():
+def test_split_no_limit():
     data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, "AB")',
             'split(a, "C")',
             'split(a, "_")'))
 
+def test_split_negative_limit():
+    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, "AB", -1)',
+            'split(a, "C", -2)',
+            'split(a, "_", -999)'))
+
+# https://github.com/NVIDIA/spark-rapids/issues/4720
+@allow_non_gpu('ProjectExec', 'StringSplit')
+def test_split_zero_limit_fallback():
+    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, "AB", 0)'),
+        exist_classes= "ProjectExec",
+        non_exist_classes= "GpuProjectExec")
+
+# https://github.com/NVIDIA/spark-rapids/issues/4720
+@allow_non_gpu('ProjectExec', 'StringSplit')
+def test_split_one_limit_fallback():
+    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, "AB", 1)'),
+        exist_classes= "ProjectExec",
+        non_exist_classes= "GpuProjectExec")
+
+def test_split_positive_limit():
+    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, "AB", 2)',
+            'split(a, "C", 3)',
+            'split(a, "_", 999)'))
+
 def test_split_re_negative_limit():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, ":", -1)',
             'split(a, "o", -2)'))
 
+# https://github.com/NVIDIA/spark-rapids/issues/4720
 @allow_non_gpu('ProjectExec', 'StringSplit')
 def test_split_re_zero_limit_fallback():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
-
     assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, ":", 0)',
             'split(a, "o", 0)'),
         exist_classes= "ProjectExec",
         non_exist_classes= "GpuProjectExec")
 
+# https://github.com/NVIDIA/spark-rapids/issues/4720
 @allow_non_gpu('ProjectExec', 'StringSplit')
 def test_split_re_one_limit_fallback():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
-
     assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, ":", 1)',
             'split(a, "o", 1)'),
         exist_classes= "ProjectExec",
         non_exist_classes= "GpuProjectExec")
 
 def test_split_re_positive_limit():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'split(a, ":", 2)',
             'split(a, ":", 5)',
             'split(a, "o", 2)',
             'split(a, "o", 5)'))
 
+def test_split_re_no_limit():
+    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
+        .with_special_case('boo:and:foo')
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, ":")',
+            'split(a, "o")'))
+
 @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
                                             (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
                                             (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index c3be6928140..1eab67b9e80 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1320,6 +1320,7 @@ class GpuStringSplitMeta(
     extractLit(expr.limit) match {
       case Some(Literal(n: Int, _)) =>
         if (n == 0 || n == 1) {
+          // https://github.com/NVIDIA/spark-rapids/issues/4720
           willNotWorkOnGpu("limit of 0 or 1 is not supported")
         }
       case _ =>
From 5cc7ccd3334955de8cd7ddcfebfcde5096add3fb Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 8 Feb 2022 16:07:50 -0700
Subject: [PATCH 08/15] update shims

---
 .../nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 4 ++--
 .../nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 4 ++--
 .../nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
index 39e09b99eca..56c2059a1ae 100644
--- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
+++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
 import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
-import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
+import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode}
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
       // use GpuStringReplace
     } else {
       try {
-        pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+        pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
       } catch {
         case e: RegexUnsupportedException =>
           willNotWorkOnGpu(e.getMessage)
diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
index b199a4d3a91..4268b60453b 100644
--- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
+++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
 import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
-import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
+import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode}
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
       // use GpuStringReplace
     } else {
       try {
-        pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+        pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
       } catch {
         case e: RegexUnsupportedException =>
           willNotWorkOnGpu(e.getMessage)
diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
index b199a4d3a91..4268b60453b 100644
--- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
+++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
 import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
-import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
+import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode}
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -39,7 +39,7 @@ class GpuRegExpReplaceMeta(
       // use GpuStringReplace
     } else {
       try {
-        pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+        pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
       } catch {
         case e: RegexUnsupportedException =>
           willNotWorkOnGpu(e.getMessage)
From 19918c2b0f0d2447573026ae3c33ec0331fce7db Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 8 Feb 2022 16:17:29 -0700
Subject: [PATCH 09/15] check that expression has been tagged

---
 .../apache/spark/sql/rapids/stringFunctions.scala | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 1eab67b9e80..bdc62f0ec99 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -1289,7 +1289,7 @@ class GpuStringSplitMeta(
   extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
   import GpuOverrides._
 
-  private var pattern: String = _
+  private var pattern: Option[String] = None
   private var isRegExp = false
 
   override def tagExprForGpu(): Unit = {
@@ -1305,13 +1305,13 @@ class GpuStringSplitMeta(
         isRegExp = RegexParser.isRegExpString(str.toString)
         if (isRegExp) {
           try {
-            pattern = new CudfRegexTranspiler(RegexSplitMode).transpile(str.toString)
+            pattern = Some(new CudfRegexTranspiler(RegexSplitMode).transpile(str.toString))
           } catch {
             case e: RegexUnsupportedException =>
               willNotWorkOnGpu(e.getMessage)
           }
         } else {
-          pattern = str.toString
+          pattern = Some(str.toString)
         }
       } else {
         willNotWorkOnGpu("null regex is not supported yet")
@@ -1330,8 +1330,10 @@ class GpuStringSplitMeta(
   override def convertToGpu(
       str: Expression,
       regexp: Expression,
-      limit: Expression): GpuExpression =
-    GpuStringSplit(str, regexp, limit, isRegExp, pattern)
+      limit: Expression): GpuExpression = {
+    GpuStringSplit(str, regexp, limit, isRegExp, pattern.getOrElse(
+      throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
+  }
 }
 
 case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,

From 481b3561f471e4a4bfb56666030583325f14d0ff Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 9 Feb 2022 10:23:59 -0700
Subject: [PATCH 10/15] update split_re tests to actually use regexp rather
 than simple strings

---
 .../src/main/python/string_test.py | 39 +++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 412b072f7a3..0b8ad1d543f 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -76,8 +76,13 @@ def test_split_re_negative_limit():
         .with_special_case('boo:and:foo')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
-            'split(a, ":", -1)',
-            'split(a, "o", -2)'))
+            'split(a, "[:]", -1)',
+            'split(a, "[o:]", -1)',
+            'split(a, "[^:]", -1)',
+            'split(a, "[^o]", -1)',
+            'split(a, "[o]{1,2}", -1)',
+            'split(a, "[bf]", -1)',
+            'split(a, "[o]", -2)'))
 
 # https://github.com/NVIDIA/spark-rapids/issues/4720
 @allow_non_gpu('ProjectExec', 'StringSplit')
 def test_split_re_zero_limit_fallback():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
     assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
-            'split(a, ":", 0)',
-            'split(a, "o", 0)'),
+            'split(a, "[:]", 0)',
+            'split(a, "[o:]", 0)',
+            'split(a, "[o]", 0)'),
         exist_classes= "ProjectExec",
         non_exist_classes= "GpuProjectExec")
 
 # https://github.com/NVIDIA/spark-rapids/issues/4720
 @allow_non_gpu('ProjectExec', 'StringSplit')
 def test_split_re_one_limit_fallback():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
     assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
-            'split(a, ":", 1)',
-            'split(a, "o", 1)'),
+            'split(a, "[:]", 1)',
+            'split(a, "[o:]", 1)',
+            'split(a, "[o]", 1)'),
         exist_classes= "ProjectExec",
         non_exist_classes= "GpuProjectExec")
 
 def test_split_re_positive_limit():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
-            'split(a, ":", 2)',
-            'split(a, ":", 5)',
-            'split(a, "o", 2)',
-            'split(a, "o", 5)'))
+            'split(a, "[:]", 2)',
+            'split(a, "[o:]", 5)',
+            'split(a, "[^:]", 2)',
+            'split(a, "[^o]", 55)',
+            'split(a, "[o]{1,2}", 999)',
+            'split(a, "[bf]", 2)',
+            'split(a, "[o]", 5)'))
 
 def test_split_re_no_limit():
     data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
         .with_special_case('boo:and:foo')
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
-            'split(a, ":")',
-            'split(a, "o")'))
+            'split(a, "[:]")',
+            'split(a, "[o:]")',
+            'split(a, "[^:]")',
+            'split(a, "[^o]")',
+            'split(a, "[o]{1,2}")',
+            'split(a, "[bf]")',
+            'split(a, "[o]")'))
 
 @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
                                             (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
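Why patch 10 rewrites the tests to use bracketed patterns: a plain delimiter
such as "o" is not classified as a regular expression, so it exercises the
simple-string split path, while "[o]" forces the transpile-and-regex path. A
sketch of that distinction (the exact heuristics live in
RegexParser.isRegExpString; treating regex metacharacters as the trigger is an
assumption about its intent, not a quote of its implementation):

    RegexParser.isRegExpString("o")    // expected: false -> plain string split
    RegexParser.isRegExpString("[o]")  // expected: true  -> transpile, then regex split
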
From 0ef483079a3b397766abd5e2ce0f501bb695b036 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 10 Feb 2022 14:50:02 -0700
Subject: [PATCH 11/15] fix merge issue

---
 .../spark/rapids/RegularExpressionTranspilerSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
index 3a0a55cc8ab..f050eeddfdd 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala
@@ -137,14 +137,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
   test("cuDF does not support octal digits 0o177 < n <= 0o377") {
     val patterns = Seq(raw"\0200", raw"\0377")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false,
+      assertUnsupported(pattern, RegexFindMode,
         "cuDF does not support octal digits 0o177 < n <= 0o377"))
   }
 
   test("cuDF does not support octal digits in character classes") {
     val patterns = Seq(raw"[\02]", raw"[\012]", raw"[\0177]")
     patterns.foreach(pattern =>
-      assertUnsupported(pattern, replace = false,
+      assertUnsupported(pattern, RegexFindMode,
         "cuDF does not support octal digits in character classes"
       )
     )

From e396c69e7e1a2e79101f967152335983b9919fe8 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 10 Feb 2022 14:51:42 -0700
Subject: [PATCH 12/15] update compatibility guide

---
 docs/compatibility.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index 06511bb937c..d7ef0005508 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -522,6 +522,7 @@ The following Apache Spark regular expression functions and expressions are supp
 - `regexp_extract`
 - `regexp_like`
 - `regexp_replace`
+- `string_split`
 
 Regular expression evaluation on the GPU can potentially have high memory overhead and cause
 out-of-memory errors. To disable regular expressions on the GPU, set
 `spark.rapids.sql.regexp.enabled=false`.
compatibility guide --- docs/compatibility.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/compatibility.md b/docs/compatibility.md index 06511bb937c..d7ef0005508 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -522,6 +522,7 @@ The following Apache Spark regular expression functions and expressions are supp - `regexp_extract` - `regexp_like` - `regexp_replace` +- `string_split` Regular expression evaluation on the GPU can potentially have high memory overhead and cause out-of-memory errors. To disable regular expressions on the GPU, set `spark.rapids.sql.regexp.enabled=false`. From c70390f7c6c0eb5146acfe83a830179bbf5006bf Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 11 Feb 2022 08:40:44 -0700 Subject: [PATCH 13/15] fix incorrect imports in shim layer --- .../shims/v2/GpuRegExpReplaceExec.scala | 4 +- .../shims/v2/GpuRegExpReplaceExec.scala | 4 +- .../shims/v2/GpuRegExpReplaceExec.scala | 4 +- .../spark/rapids/conditionalExpressions.scala | 4 + .../apache/spark/sql/rapids/predicates.scala | 118 +++++++++++++++++- .../com/nvidia/spark/rapids/CastOpSuite.scala | 8 ++ 6 files changed, 135 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 4821fb6805e..ce13571e910 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index aab6c705cec..973948518c5 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, 
GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index aab6c705cec..973948518c5 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala index 9781ab3e6e4..68924cfca72 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala @@ -121,6 +121,10 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpr // predicate boolean array results in the two T values mapping to // indices 0 and 1, respectively. 
+ // [F, null, T, F, T] + // [0, 0, 0, 1, 1] + //        [ 0, 1 ] + val prefixSumExclusive = withResource(boolToInt(predicate)) { boolsAsInts => boolsAsInts.scan( ScanAggregation.sum(), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala index a98a728a1b3..f42484d7a54 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf._ import ai.rapids.cudf.ast.BinaryOperator import com.nvidia.spark.rapids._ - import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant, Predicate} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, FloatType} +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} trait GpuPredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { @@ -66,6 +66,122 @@ case class GpuAnd(left: Expression, right: Expression) extends CudfBinaryOperato override def binaryOp: BinaryOp = BinaryOp.NULL_LOGICAL_AND override def astOperator: Option[BinaryOperator] = Some(ast.BinaryOperator.NULL_LOGICAL_AND) + + protected def filterBatch( + tbl: Table, + pred: ColumnVector, + colTypes: Array[DataType]): ColumnarBatch = { + withResource(tbl.filter(pred)) { filteredData => + GpuColumnVector.from(filteredData, colTypes) + } + } + + def exampleTest: Unit = { + + + ColumnVector + + + } + + private def columnarEvalWithSideEffects(batch: ColumnarBatch): Any = { + val leftExpr = left.asInstanceOf[GpuExpression] + val rightExpr = right.asInstanceOf[GpuExpression] + val colTypes = GpuColumnVector.extractTypes(batch) + + withResource(GpuColumnVector.from(batch)) { tbl => + withResource(GpuExpressionsUtils.columnarEvalToColumn(leftExpr, batch)) { lhsBool => + + GpuColumnVector.debug("lhsBool", lhsBool.getBase) + + // filter to get rows where lhs was true + val rhsBool = withResource(filterBatch(tbl, lhsBool, colTypes)) { rhsBatch => + rightExpr.columnarEval(rhsBatch) + } + + GpuColumnVector.debug("rhsBool", rhsBool.getBase) + + // a AND (CAST(b as INT) + 2) > 0 + // + // a b + // true MAX_INT - 2 ... a AND b = true + // false MAX_INT - 2 + // false MAX_INT <-- currently fails + + // lhsBool = { true, false, false } + + // filtered batch: + // true MAX_INT - 2 ... a AND b = true + + // rhsBool = { true } + + // perform AND lhsBool and rhsBool + + // gather(lhsBool) = { 0, 1, 1 } + // combine lhsBool with gather => { 0 } into rhsBool + + // val rhsAdjusted = gather(lhsBool, rhsBool) + // { true, false, false } + + // lhsBool.and(rhsAdjusted) + + + + + // { 1 + + + + // lhsBool = { true, false, false } + // rhsBool = { true } -> { true, false, false } + + + + + + // TODO: verify the best way to create FALSE_EXPR + // get the inverse of leftBool + withResource(lhsBool.getBase.unaryOp(UnaryOp.NOT)) { leftInverted => + // TODO: How to evaluate RHS? on filtered batch or all batches?
+ val cView = withResourceIfAllowed(lhsBool) { lhs => + withResource(GpuExpressionsUtils.columnarEvalToColumn(rightExpr, batch)) { rhsBool => + withResourceIfAllowed(rightExpr.columnarEval(batch)) { rhs => + (lhs, rhs) match { + case (l: GpuColumnVector, r: GpuColumnVector) => + GpuColumnVector.from(doColumnar(l, r), dataType) + case (l: GpuScalar, r: GpuColumnVector) => + GpuColumnVector.from(doColumnar(l, r), dataType) + case (l: GpuColumnVector, r: GpuScalar) => + GpuColumnVector.from(doColumnar(l, r), dataType) + case (l: GpuScalar, r: GpuScalar) => + GpuColumnVector.from(doColumnar(batch.numRows(), l, r), dataType) + case (l, r) => + throw new UnsupportedOperationException(s"Unsupported data '($l: " + + s"${l.getClass}, $r: ${r.getClass})' for GPU binary expression.") + } + } + } + } + val falseExpr = withResource(GpuScalar.from(false, BooleanType)) { falseScalar => + GpuColumnVector.from(falseScalar, lhsBool.getRowCount.toInt, dataType) + } + val finalReturn = leftInverted.ifElse(falseExpr.getBase, cView.getBase) + GpuColumnVector.from(finalReturn, dataType) + } + } + } + } + + // TODO: Is this right place? or overriding the doColumnar? + override def columnarEval(batch: ColumnarBatch): Any = { + val rightExpr = right.asInstanceOf[GpuExpression] + + if (rightExpr.hasSideEffects) { + columnarEvalWithSideEffects(batch) + } else { + super.columnarEval(batch) + } + } } case class GpuOr(left: Expression, right: Expression) extends CudfBinaryOperator with Predicate { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala index ae4db3996a6..3a28cd26a47 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala @@ -950,6 +950,14 @@ class CastOpSuite extends GpuExpressionTestSuite { } } + test("") { + val lhs = ColumnVector.fromBooleans(true, false, false) + val rhs = ColumnVector.fromBooleans(true, false, false) + val expected = ColumnVector.fromBooleans(true, false, false) + + + } + test("CAST string to float - sanitize step") { val testPairs = Seq( ("\tinf", "inf"), From b7f0cbae81162fe4f4aea4f76557bb4dc5a0beae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 11 Feb 2022 09:16:49 -0700 Subject: [PATCH 14/15] Revert "fix incorrect imports in shim layer" This reverts commit c70390f7c6c0eb5146acfe83a830179bbf5006bf.
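The reverted commit also carried, in comments, a worked example of mapping rows where a predicate is true onto a compacted result column via an exclusive prefix sum. For when this work is picked up again, here is a minimal CPU-side sketch of that indexing in plain Scala collections; it is illustrative only, not plugin code, and the `predicate` value below is made up:

    // An exclusive prefix sum over a nullable boolean predicate maps each
    // "true" row to the index of its value in a compacted result column.
    val predicate = Seq(Some(false), None, Some(true), Some(false), Some(true))

    // Nulls and false both contribute 0; true contributes 1.
    val asInts = predicate.map {
      case Some(true) => 1
      case _          => 0
    }                                                        // [0, 0, 1, 0, 1]

    // scanLeft includes the seed, so dropping the last element
    // yields the exclusive prefix sum.
    val prefixSumExclusive = asInts.scanLeft(0)(_ + _).init  // [0, 0, 0, 1, 1]

    // The two true rows therefore gather results 0 and 1.
    val gatherIndices = predicate.zip(prefixSumExclusive)
      .collect { case (Some(true), idx) => idx }             // Seq(0, 1)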
--- .../shims/v2/GpuRegExpReplaceExec.scala | 4 +- .../shims/v2/GpuRegExpReplaceExec.scala | 4 +- .../shims/v2/GpuRegExpReplaceExec.scala | 4 +- .../spark/rapids/conditionalExpressions.scala | 4 - .../apache/spark/sql/rapids/predicates.scala | 118 +----------------- .../com/nvidia/spark/rapids/CastOpSuite.scala | 8 -- 6 files changed, 7 insertions(+), 135 deletions(-) diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index ce13571e910..4821fb6805e 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 973948518c5..aab6c705cec 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 973948518c5..aab6c705cec 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, 
RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala index 68924cfca72..9781ab3e6e4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala @@ -121,10 +121,6 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpr // predicate boolean array results in the two T values mapping to // indices 0 and 1, respectively. - // [F, null, T, F, T] - // [0, 0, 0, 1, 1] - //        [ 0, 1 ] - val prefixSumExclusive = withResource(boolToInt(predicate)) { boolsAsInts => boolsAsInts.scan( ScanAggregation.sum(), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala index f42484d7a54..a98a728a1b3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf._ import ai.rapids.cudf.ast.BinaryOperator import com.nvidia.spark.rapids._ + import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant, Predicate} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, FloatType} -import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} trait GpuPredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { @@ -66,122 +66,6 @@ case class GpuAnd(left: Expression, right: Expression) extends CudfBinaryOperato override def binaryOp: BinaryOp = BinaryOp.NULL_LOGICAL_AND override def astOperator: Option[BinaryOperator] = Some(ast.BinaryOperator.NULL_LOGICAL_AND) - - protected def filterBatch( - tbl: Table, - pred: ColumnVector, - colTypes: Array[DataType]): ColumnarBatch = { - withResource(tbl.filter(pred)) { filteredData => - GpuColumnVector.from(filteredData, colTypes) - } - } - - def exampleTest: Unit = { - - - ColumnVector - - - } - - private def columnarEvalWithSideEffects(batch: ColumnarBatch): Any = { - val leftExpr = left.asInstanceOf[GpuExpression] - val rightExpr = right.asInstanceOf[GpuExpression] - val colTypes = GpuColumnVector.extractTypes(batch) - - withResource(GpuColumnVector.from(batch)) { tbl => - withResource(GpuExpressionsUtils.columnarEvalToColumn(leftExpr, batch)) { lhsBool => - - GpuColumnVector.debug("lhsBool", lhsBool.getBase) - - // filter to get rows where lhs was true - val rhsBool = withResource(filterBatch(tbl, lhsBool, colTypes)) { rhsBatch => - rightExpr.columnarEval(rhsBatch) - } - - GpuColumnVector.debug("rhsBool", rhsBool.getBase) - - // a AND
(CAST(b as INT) + 2) > 0 - // - // a b - // true MAX_INT - 2 ... a AND b = true - // false MAX_INT - 2 - // false MAX_INT <-- currently fails - - // lhsBool = { true, false, false } - - // filtered batch: - // true MAX_INT - 2 ... a AND b = true - - // rhsBool = { true } - - // perform AND lhsBool and rhsBool - - // gather(lhsBool) = { 0, 1, 1 } - // combine lhsBool with gather => { 0 } into rhsBool - - // val rhsAdjusted = gather(lhsBool, rhsBool) - // { true, false, false } - - // lhsBool.and(rhsAdjusted) - - - - - // { 1 - - - - // lhsBool = { true, false, false } - // rhsBool = { true } -> { true, false, false } - - - - - - // TODO: verify the best way to create FALSE_EXPR - // get the inverse of leftBool - withResource(lhsBool.getBase.unaryOp(UnaryOp.NOT)) { leftInverted => - // TODO: How to evaluate RHS? on filtered batch or all batches? - val cView = withResourceIfAllowed(lhsBool) { lhs => - withResource(GpuExpressionsUtils.columnarEvalToColumn(rightExpr, batch)) { rhsBool => - withResourceIfAllowed(rightExpr.columnarEval(batch)) { rhs => - (lhs, rhs) match { - case (l: GpuColumnVector, r: GpuColumnVector) => - GpuColumnVector.from(doColumnar(l, r), dataType) - case (l: GpuScalar, r: GpuColumnVector) => - GpuColumnVector.from(doColumnar(l, r), dataType) - case (l: GpuColumnVector, r: GpuScalar) => - GpuColumnVector.from(doColumnar(l, r), dataType) - case (l: GpuScalar, r: GpuScalar) => - GpuColumnVector.from(doColumnar(batch.numRows(), l, r), dataType) - case (l, r) => - throw new UnsupportedOperationException(s"Unsupported data '($l: " + - s"${l.getClass}, $r: ${r.getClass})' for GPU binary expression.") - } - } - } - } - val falseExpr = withResource(GpuScalar.from(false, BooleanType)) { falseScalar => - GpuColumnVector.from(falseScalar, lhsBool.getRowCount.toInt, dataType) - } - val finalReturn = leftInverted.ifElse(falseExpr.getBase, cView.getBase) - GpuColumnVector.from(finalReturn, dataType) - } - } - } - } - - // TODO: Is this right place? or overriding the doColumnar?
- override def columnarEval(batch: ColumnarBatch): Any = { - val rightExpr = right.asInstanceOf[GpuExpression] - - if (rightExpr.hasSideEffects) { - columnarEvalWithSideEffects(batch) - } else { - super.columnarEval(batch) - } - } } case class GpuOr(left: Expression, right: Expression) extends CudfBinaryOperator with Predicate { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala index 3a28cd26a47..ae4db3996a6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala @@ -950,14 +950,6 @@ class CastOpSuite extends GpuExpressionTestSuite { } } - test("") { - val lhs = ColumnVector.fromBooleans(true, false, false) - val rhs = ColumnVector.fromBooleans(true, false, false) - val expected = ColumnVector.fromBooleans(true, false, false) - - - } - test("CAST string to float - sanitize step") { val testPairs = Seq( ("\tinf", "inf"), From 4c94170e229ffb393ac9279a10583617c0ab6acc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 11 Feb 2022 09:19:43 -0700 Subject: [PATCH 15/15] fix incorrect imports in shim layer --- .../nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 4 ++-- .../nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 4 ++-- .../nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 4821fb6805e..ce13571e910 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index aab6c705cec..973948518c5 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, 
RegexReplaceMode, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index aab6c705cec..973948518c5 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,10 +15,10 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace, RegexReplaceMode} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String
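A closing note on the `string_split` limit semantics exercised by the integration tests earlier in this series: the CPU side of those comparisons ultimately goes through the JVM's regex engine. Below is a self-contained Scala reference (plain JVM code, no plugin or cuDF dependencies; the sample string is the tests' `boo:and:foo` special case) showing `java.lang.String.split`'s limit behavior. The zero-limit case is subtle, which lines up with the limit-0 tests above expecting a CPU fallback to `ProjectExec`:

    // java.lang.String.split limit semantics:
    //   limit > 0  -> at most `limit` elements; the last element keeps the rest
    //   limit == 0 -> unlimited splits, but trailing empty strings are removed
    //   limit < 0  -> unlimited splits, trailing empty strings are kept
    object SplitSemantics extends App {
      val s = "boo:and:foo"
      println(s.split(":", 2).mkString("|"))  // boo|and:foo
      println(s.split(":", 5).mkString("|"))  // boo|and|foo
      println(s.split("o", -1).mkString("|")) // b||:and:f||
      println(s.split("o", 0).mkString("|"))  // b||:and:f
    }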