diff --git a/docs/compatibility.md b/docs/compatibility.md index 62b9b68af4b..d7ef0005508 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -522,6 +522,7 @@ The following Apache Spark regular expression functions and expressions are supp - `regexp_extract` - `regexp_like` - `regexp_replace` +- `string_split` Regular expression evaluation on the GPU can potentially have high memory overhead and cause out-of-memory errors. To disable regular expressions on the GPU, set `spark.rapids.sql.regexp.enabled=false`. @@ -535,6 +536,7 @@ Here are some examples of regular expression patterns that are not supported on - Line anchor `$` - String anchor `\Z` - String anchor `\z` is not supported by `regexp_replace` +- Line and string anchors are not supported by `string_split` - Non-digit character class `\D` - Non-word character class `\W` - Word and non-word boundaries, `\b` and `\B` diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index c3570affe98..a45f3f4cdd2 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -14,7 +14,9 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, \ + assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error, \ + assert_cpu_and_gpu_are_equal_collect_with_capture from conftest import is_databricks_runtime from data_gen import * from marks import * @@ -25,15 +27,115 @@ def mk_str_gen(pattern): return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}') -def test_split(): +def test_split_no_limit(): data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}') - delim = '_' assert_gpu_and_cpu_are_equal_collect( lambda spark : 
unary_op_df(spark, data_gen).selectExpr( 'split(a, "AB")', 'split(a, "C")', 'split(a, "_")')) +def test_split_negative_limit(): + data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "AB", -1)', + 'split(a, "C", -2)', + 'split(a, "_", -999)')) + +# https://github.com/NVIDIA/spark-rapids/issues/4720 +@allow_non_gpu('ProjectExec', 'StringSplit') +def test_split_zero_limit_fallback(): + data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}') + assert_cpu_and_gpu_are_equal_collect_with_capture( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "AB", 0)'), + exist_classes= "ProjectExec", + non_exist_classes= "GpuProjectExec") + +# https://github.com/NVIDIA/spark-rapids/issues/4720 +@allow_non_gpu('ProjectExec', 'StringSplit') +def test_split_one_limit_fallback(): + data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}') + assert_cpu_and_gpu_are_equal_collect_with_capture( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "AB", 1)'), + exist_classes= "ProjectExec", + non_exist_classes= "GpuProjectExec") + +def test_split_positive_limit(): + data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "AB", 2)', + 'split(a, "C", 3)', + 'split(a, "_", 999)')) + +def test_split_re_negative_limit(): + data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ + .with_special_case('boo:and:foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[:]", -1)', + 'split(a, "[o:]", -1)', + 'split(a, "[^:]", -1)', + 'split(a, "[^o]", -1)', + 'split(a, "[o]{1,2}", -1)', + 'split(a, "[bf]", -1)', + 'split(a, "[o]", -2)')) + +# https://github.com/NVIDIA/spark-rapids/issues/4720 +@allow_non_gpu('ProjectExec', 'StringSplit') +def test_split_re_zero_limit_fallback(): + data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ + 
.with_special_case('boo:and:foo') + assert_cpu_and_gpu_are_equal_collect_with_capture( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[:]", 0)', + 'split(a, "[o:]", 0)', + 'split(a, "[o]", 0)'), + exist_classes= "ProjectExec", + non_exist_classes= "GpuProjectExec") + +# https://github.com/NVIDIA/spark-rapids/issues/4720 +@allow_non_gpu('ProjectExec', 'StringSplit') +def test_split_re_one_limit_fallback(): + data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ + .with_special_case('boo:and:foo') + assert_cpu_and_gpu_are_equal_collect_with_capture( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[:]", 1)', + 'split(a, "[o:]", 1)', + 'split(a, "[o]", 1)'), + exist_classes= "ProjectExec", + non_exist_classes= "GpuProjectExec") + +def test_split_re_positive_limit(): + data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ + .with_special_case('boo:and:foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[:]", 2)', + 'split(a, "[o:]", 5)', + 'split(a, "[^:]", 2)', + 'split(a, "[^o]", 55)', + 'split(a, "[o]{1,2}", 999)', + 'split(a, "[bf]", 2)', + 'split(a, "[o]", 5)')) + +def test_split_re_no_limit(): + data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \ + .with_special_case('boo:and:foo') + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'split(a, "[:]")', + 'split(a, "[o:]")', + 'split(a, "[^:]")', + 'split(a, "[^o]")', + 'split(a, "[o]{1,2}")', + 'split(a, "[bf]")', + 'split(a, "[o]")')) + @pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'), (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'), (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn) diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 522970b8bf6..ce13571e910 100644 --- 
a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} @@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta( // use GpuStringReplace } else { try { - pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString)) + pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString)) } catch { case e: RegexUnsupportedException => willNotWorkOnGpu(e.getMessage) diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala index 5b4a68270a0..e3ea7beae21 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala @@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta( // use GpuStringReplace } else { try { - pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString)) + pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString)) } catch { case e: RegexUnsupportedException => willNotWorkOnGpu(e.getMessage) diff --git 
a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 8cd16be45a2..973948518c5 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} @@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta( // use GpuStringReplace } else { try { - pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString)) + pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString)) } catch { case e: RegexUnsupportedException => willNotWorkOnGpu(e.getMessage) diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala index 8cd16be45a2..973948518c5 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, 
QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException} +import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace} @@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta( // use GpuStringReplace } else { try { - pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString)) + pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString)) } catch { case e: RegexUnsupportedException => willNotWorkOnGpu(e.getMessage) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index b5656c30a00..09ca02032bf 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -433,6 +433,11 @@ object RegexParser { } } +sealed trait RegexMode +object RegexFindMode extends RegexMode +object RegexReplaceMode extends RegexMode +object RegexSplitMode extends RegexMode + /** * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception * if this is not possible. 
@@ -440,7 +445,7 @@ object RegexParser { - * @param replace True if performing a replacement (regexp_replace), false - * if matching only (rlike) + * @param mode RegexFindMode when matching only (rlike, regexp_extract), RegexReplaceMode when + * performing a replacement (regexp_replace), RegexSplitMode when splitting (split) */ -class CudfRegexTranspiler(replace: Boolean) { +class CudfRegexTranspiler(mode: RegexMode) { // cuDF throws a "nothing to repeat" exception for many of the edge cases that are // rejected by the transpiler @@ -472,6 +477,8 @@ class CudfRegexTranspiler(replace: Boolean) { case '$' => // see https://github.com/NVIDIA/spark-rapids/issues/4533 throw new RegexUnsupportedException("line anchor $ is not supported") + case '^' if mode == RegexSplitMode => + throw new RegexUnsupportedException("line anchor ^ is not supported in split mode") case _ => regex } @@ -506,8 +513,14 @@ class CudfRegexTranspiler(replace: Boolean) { case 's' | 'S' => // see https://github.com/NVIDIA/spark-rapids/issues/4528 throw new RegexUnsupportedException("whitespace classes are not supported") + case 'A' if mode == RegexSplitMode => + throw new RegexUnsupportedException("string anchor \\A is not supported in split mode") + case 'Z' if mode == RegexSplitMode => + throw new RegexUnsupportedException("string anchor \\Z is not supported in split mode") + case 'z' if mode == RegexSplitMode => + throw new RegexUnsupportedException("string anchor \\z is not supported in split mode") case 'z' => - if (replace) { + if (mode == RegexReplaceMode) { // see https://github.com/NVIDIA/spark-rapids/issues/4425 throw new RegexUnsupportedException( "string anchor \\z is not supported in replace mode") @@ -607,7 +620,7 @@ class CudfRegexTranspiler(replace: Boolean) { RegexSequence(parts.map(rewrite)) case RegexRepetition(base, quantifier) => (base, quantifier) match { - case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) => + case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) => // example: pattern " ?", input "] b[", replace with "X": // java: X]XXbX[X // cuDF: XXXX] b[ diff --git 
a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 1aa7171862c..e54f34c4e76 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -833,7 +833,7 @@ class GpuRLikeMeta( case Literal(str: UTF8String, DataTypes.StringType) if str != null => try { // verify that we support this regex and can transpile it to cuDF format - pattern = Some(new CudfRegexTranspiler(replace = false).transpile(str.toString)) + pattern = Some(new CudfRegexTranspiler(RegexFindMode).transpile(str.toString)) } catch { case e: RegexUnsupportedException => willNotWorkOnGpu(e.getMessage) @@ -981,7 +981,7 @@ class GpuRegExpExtractMeta( try { val javaRegexpPattern = str.toString // verify that we support this regex and can transpile it to cuDF format - val cudfRegexPattern = new CudfRegexTranspiler(replace = false) + val cudfRegexPattern = new CudfRegexTranspiler(RegexFindMode) .transpile(javaRegexpPattern) pattern = Some(cudfRegexPattern) numGroups = countGroups(new RegexParser(javaRegexpPattern).parse()) @@ -1324,6 +1324,9 @@ class GpuStringSplitMeta( extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) { import GpuOverrides._ + private var pattern: Option[String] = None + private var isRegExp = false + override def tagExprForGpu(): Unit = { val regexp = extractLit(expr.regex) if (regexp.isEmpty) { @@ -1331,29 +1334,46 @@ class GpuStringSplitMeta( } else { val str = regexp.get.value.asInstanceOf[UTF8String] if (str != null) { - if (RegexParser.isRegExpString(str.toString)) { - willNotWorkOnGpu("regular expressions are not supported yet") - } if (str.numChars() == 0) { willNotWorkOnGpu("An empty regex is not supported yet") } + isRegExp = RegexParser.isRegExpString(str.toString) + if (isRegExp) { + try { + pattern = Some(new 
CudfRegexTranspiler(RegexSplitMode).transpile(str.toString)) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) + } + } else { + pattern = Some(str.toString) + } } else { willNotWorkOnGpu("null regex is not supported yet") } } - if (!isLit(expr.limit)) { - willNotWorkOnGpu("only literal limit is supported") + extractLit(expr.limit) match { + case Some(Literal(n: Int, _)) => + if (n == 0 || n == 1) { + // https://github.com/NVIDIA/spark-rapids/issues/4720 + willNotWorkOnGpu("limit of 0 or 1 is not supported") + } + case _ => + willNotWorkOnGpu("only literal limit is supported") } } override def convertToGpu( str: Expression, regexp: Expression, - limit: Expression): GpuExpression = - GpuStringSplit(str, regexp, limit) + limit: Expression): GpuExpression = { + GpuStringSplit(str, regexp, limit, isRegExp, pattern.getOrElse( + throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))) + } } -case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression) - extends GpuTernaryExpression with ImplicitCastInputTypes { +case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression, + isRegExp: Boolean, pattern: String) + extends GpuTernaryExpression with ImplicitCastInputTypes { override def dataType: DataType = ArrayType(StringType, containsNull = false) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -1361,14 +1381,12 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression) override def second: Expression = regex override def third: Expression = limit - def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType)) - override def prettyName: String = "split" override def doColumnar(str: GpuColumnVector, regex: GpuScalar, limit: GpuScalar): ColumnVector = { val intLimit = limit.getValue.asInstanceOf[Int] - str.getBase.stringSplitRecord(regex.getBase, intLimit) + 
str.getBase.stringSplitRecord(pattern, intLimit, isRegExp) } override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 1e2b216c442..f050eeddfdd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -49,9 +49,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { for (pattern <- cudfInvalidPatterns) { // check that this is valid in Java Pattern.compile(pattern) - Seq(true, false).foreach { replace => + Seq(RegexFindMode, RegexReplaceMode).foreach { mode => try { - if (replace) { + if (mode == RegexReplaceMode) { gpuReplace(pattern, inputs) } else { gpuContains(pattern, inputs) @@ -61,9 +61,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { case e: CudfException => // expected, now make sure that the transpiler can detect this try { - transpile(pattern, replace) + transpile(pattern, mode) fail( - s"transpiler failed to detect invalid cuDF pattern (replace=$replace): $pattern", e) + s"transpiler failed to detect invalid cuDF pattern (mode=$mode): $pattern", e) } catch { case _: RegexUnsupportedException => // expected @@ -76,7 +76,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support choice with nothing to repeat") { val patterns = Seq("b+|^\t") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, "nothing to repeat") + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") ) } @@ -100,14 +100,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support possessive quantifier") { val patterns = Seq("a*+", "a|(a?|a*+)") patterns.foreach(pattern => - assertUnsupported(pattern, replace = 
false, "nothing to repeat") + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") ) } test("cuDF does not support empty sequence") { val patterns = Seq("", "a|", "()") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, "empty sequence not supported") + assertUnsupported(pattern, RegexFindMode, "empty sequence not supported") ) } @@ -115,14 +115,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // note that we could choose to transpile and escape the '{' and '}' characters val patterns = Seq("{1,2}", "{1,}", "{1}", "{2,1}") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, "nothing to repeat") + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") ) } test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") patterns.foreach(pattern => { - assertUnsupported(pattern, replace = false, + assertUnsupported(pattern, RegexFindMode, "sequences that only contain '^' or '$' are not supported") }) } @@ -130,21 +130,21 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support null in pattern") { val patterns = Seq("\u0000", "a\u0000b", "a(\u0000)b", "a[a-b][\u0000]") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, + assertUnsupported(pattern, RegexFindMode, "cuDF does not support null characters in regular expressions")) } test("cuDF does not support octal digits 0o177 < n <= 0o377") { val patterns = Seq(raw"\0200", raw"\0377") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, + assertUnsupported(pattern, RegexFindMode, "cuDF does not support octal digits 0o177 < n <= 0o377")) } test("cuDF does not support octal digits in character classes") { val patterns = Seq(raw"[\02]", raw"[\012]", raw"[\0177]") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, + assertUnsupported(pattern, RegexFindMode, "cuDF does not support octal digits in character classes" ) ) 
@@ -154,7 +154,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // see https://github.com/NVIDIA/spark-rapids/issues/4486 val patterns = Seq(raw"\xA9", raw"\x00A9", raw"\x10FFFF") patterns.foreach(pattern => - assertUnsupported(pattern, replace = false, + assertUnsupported(pattern, RegexFindMode, "cuDF does not support hex digits consistently with Spark")) } @@ -172,8 +172,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("string anchor \\Z fall back to CPU") { - for (replace <- Seq(true, false)) { - assertUnsupported("\\Z", replace, "string anchor \\Z is not supported") + for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { + assertUnsupported("\\Z", mode, "string anchor \\Z is not supported") } } @@ -184,8 +184,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("line anchor $ fall back to CPU") { - for (replace <- Seq(true, false)) { - assertUnsupported("a$b", replace, "line anchor $ is not supported") + for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { + assertUnsupported("a$b", mode, "line anchor $ is not supported") } } @@ -220,7 +220,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("transpile character class unescaped range symbol") { val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]") val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", "a(?:[\r\n]|[^\\-])") - val transpiler = new CudfRegexTranspiler(replace=false) + val transpiler = new CudfRegexTranspiler(RegexFindMode) val transpiled = patterns.map(transpiler.transpile) assert(transpiled === expected) } @@ -274,15 +274,15 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("fall back to CPU for \\D") { // see https://github.com/NVIDIA/spark-rapids/issues/4475 - for (replace <- Seq(true, false)) { - assertUnsupported("\\D", replace, "non-digit class \\D is not supported") + for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { + 
assertUnsupported("\\D", mode, "non-digit class \\D is not supported") } } test("fall back to CPU for \\W") { // see https://github.com/NVIDIA/spark-rapids/issues/4475 - for (replace <- Seq(true, false)) { - assertUnsupported("\\W", replace, "non-word class \\W is not supported") + for (mode <- Seq(RegexFindMode, RegexReplaceMode)) { + assertUnsupported("\\W", mode, "non-word class \\W is not supported") } } @@ -298,7 +298,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // testing with this limited set of characters finds issues much // faster than using the full ASCII set // CR and LF has been excluded due to known issues - doFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), replace = false) + doFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), RegexFindMode) } test("compare CPU and GPU: regexp replace simple regular expressions") { @@ -316,7 +316,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support some uses of line anchors in regexp_replace") { Seq("^$", "^", "$", "(^)($)", "(((^^^)))$", "^*", "$*", "^+", "$+", "^|$", "^^|$$").foreach( pattern => - assertUnsupported(pattern, replace = true, + assertUnsupported(pattern, RegexReplaceMode, "sequences that only contain '^' or '$' are not supported") ) } @@ -332,27 +332,27 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { // testing with this limited set of characters finds issues much // faster than using the full ASCII set // LF has been excluded due to known issues - doFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true) + doFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), RegexReplaceMode) } test("compare CPU and GPU: regexp find fuzz test printable ASCII chars plus CR, LF, and TAB") { // CR and LF has been excluded due to known issues - doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\n\t"), replace = false) + doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\n\t"), RegexFindMode) } test("compare CPU and GPU: fuzz test ASCII 
chars") { // LF has been excluded due to known issues val chars = (0x00 to 0x7F) .map(_.toChar) - doFuzzTest(Some(chars.mkString), replace = true) + doFuzzTest(Some(chars.mkString), RegexReplaceMode) } test("compare CPU and GPU: regexp find fuzz test all chars") { // this test cannot be enabled until we support CR and LF - doFuzzTest(None, replace = false) + doFuzzTest(None, RegexFindMode) } - private def doFuzzTest(validChars: Option[String], replace: Boolean) { + private def doFuzzTest(validChars: Option[String], mode: RegexMode) { val r = new EnhancedRandom(new Random(seed = 0L), options = FuzzerOptions(validChars, maxStringLen = 12)) @@ -365,13 +365,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { while (patterns.size < 5000) { val pattern = r.nextString() if (!patterns.contains(pattern)) { - if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, replace)).isSuccess) { + if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, mode)).isSuccess) { patterns += pattern } } } - if (replace) { + if (mode == RegexReplaceMode) { assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data) } else { assertCpuGpuMatchesRegexpFind(patterns.toSeq, data) @@ -379,15 +379,67 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("AST fuzz test - regexp_find") { - doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), replace = false) + doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_FIND), RegexFindMode) } test("AST fuzz test - regexp_replace") { - doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), replace = true) + doAstFuzzTest(Some(REGEXP_LIMITED_CHARS_REPLACE), RegexReplaceMode) } - private def doAstFuzzTest(validChars: Option[String], replace: Boolean) { + test("string split - limit < 0") { + val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]") + val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo") + for (limit <- Seq(Integer.MIN_VALUE, -2, -1)) { + doStringSplitTest(patterns, data, limit) + } + } + + 
test("string split - limit > 1") { + val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]") + val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo") + for (limit <- Seq(2, 5, Integer.MAX_VALUE)) { + doStringSplitTest(patterns, data, limit) + } + } + + test("string split fuzz") { + val (data, patterns) = generateDataAndPatterns(Some(REGEXP_LIMITED_CHARS_REPLACE), + RegexSplitMode) + for (limit <- Seq(-2, -1, 2, 5)) { + doStringSplitTest(patterns, data, limit) + } + } + + def doStringSplitTest(patterns: Set[String], data: Seq[String], limit: Int) { + for (pattern <- patterns) { + val cpu = cpuSplit(pattern, data, limit) + val cudfPattern = new CudfRegexTranspiler(RegexSplitMode).transpile(pattern) + val gpu = gpuSplit(cudfPattern, data, limit) + assert(cpu.length == gpu.length) + for (i <- cpu.indices) { + val cpuArray = cpu(i) + val gpuArray = gpu(i) + if (!cpuArray.sameElements(gpuArray)) { + fail(s"string_split pattern=${toReadableString(pattern)} " + + s"data=${toReadableString(data(i))} limit=$limit " + + s"\nCPU [${cpuArray.length}]: ${toReadableString(cpuArray.mkString(", "))} " + + s"\nGPU [${gpuArray.length}]: ${toReadableString(gpuArray.mkString(", "))}") + } + } + } + } + + private def doAstFuzzTest(validChars: Option[String], mode: RegexMode) { + val (data, patterns) = generateDataAndPatterns(validChars, mode) + if (mode == RegexReplaceMode) { + assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data) + } else { + assertCpuGpuMatchesRegexpFind(patterns.toSeq, data) + } + } + private def generateDataAndPatterns(validChars: Option[String], mode: RegexMode) + : (Seq[String], Set[String]) = { val r = new EnhancedRandom(new Random(seed = 0L), FuzzerOptions(validChars, maxStringLen = 12)) @@ -401,23 +453,18 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { while (patterns.size < 5000) { val pattern = fuzzer.generate(0).toRegexString if (!patterns.contains(pattern)) { - if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, 
replace)).isSuccess) { + if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern, mode)).isSuccess) { patterns += pattern } } } - - if (replace) { - assertCpuGpuMatchesRegexpReplace(patterns.toSeq, data) - } else { - assertCpuGpuMatchesRegexpFind(patterns.toSeq, data) - } + (data, patterns.toSet) } private def assertCpuGpuMatchesRegexpFind(javaPatterns: Seq[String], input: Seq[String]) = { for (javaPattern <- javaPatterns) { val cpu = cpuContains(javaPattern, input) - val cudfPattern = new CudfRegexTranspiler(replace = false).transpile(javaPattern) + val cudfPattern = new CudfRegexTranspiler(RegexFindMode).transpile(javaPattern) val gpu = try { gpuContains(cudfPattern, input) } catch { @@ -440,7 +487,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { input: Seq[String]) = { for (javaPattern <- javaPatterns) { val cpu = cpuReplace(javaPattern, input) - val cudfPattern = new CudfRegexTranspiler(replace = true).transpile(javaPattern) + val cudfPattern = new CudfRegexTranspiler(RegexReplaceMode).transpile(javaPattern) val gpu = try { gpuReplace(cudfPattern, input) } catch { @@ -509,18 +556,36 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { input.map(s => p.matcher(s).replaceAll(REPLACE_STRING)).toArray } + private def cpuSplit(pattern: String, input: Seq[String], limit: Int): Seq[Array[String]] = { + input.map(s => s.split(pattern, limit)) + } + + private def gpuSplit(pattern: String, input: Seq[String], limit: Int): Seq[Array[String]] = { + val isRegex = RegexParser.isRegExpString(pattern) + withResource(ColumnVector.fromStrings(input: _*)) { cv => + withResource(cv.stringSplitRecord(pattern, limit, isRegex)) { x => + withResource(x.copyToHost()) { hcv => + (0 until hcv.getRowCount.toInt).map(i => { + val list = hcv.getList(i) + list.toArray(new Array[String](list.size())) + }) + } + } + } + } + private def doTranspileTest(pattern: String, expected: String) { - val transpiled: String = transpile(pattern, replace 
= false) + val transpiled: String = transpile(pattern, RegexFindMode) assert(toReadableString(transpiled) === toReadableString(expected)) } - private def transpile(pattern: String, replace: Boolean): String = { - new CudfRegexTranspiler(replace).transpile(pattern) + private def transpile(pattern: String, mode: RegexMode): String = { + new CudfRegexTranspiler(mode).transpile(pattern) } - private def assertUnsupported(pattern: String, replace: Boolean, message: String): Unit = { + private def assertUnsupported(pattern: String, mode: RegexMode, message: String): Unit = { val e = intercept[RegexUnsupportedException] { - transpile(pattern, replace) + transpile(pattern, mode) } assert(e.getMessage.startsWith(message), pattern) }