diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index 72f377a031e..3d4d8517b18 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -28,7 +28,7 @@ else: pytestmark = pytest.mark.regexp -_regexp_conf = { 'spark.rapids.sql.regexp.enabled': 'true' } +_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True } def mk_str_gen(pattern): return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}') @@ -804,6 +804,59 @@ def test_regexp_replace_unicode_support(): ), conf=_regexp_conf) +@allow_non_gpu('ProjectExec', 'RegExpReplace') +def test_regexp_replace_fallback(): + gen = mk_str_gen('[abcdef]{0,2}') + + conf = { 'spark.rapids.sql.regexp.enabled': 'false' } + + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "[a-z]+", "PROD")', + 'REGEXP_REPLACE(a, "aa", "PROD")', + ), + cpu_fallback_class_name='RegExpReplace', + conf=conf + ) + +@pytest.mark.parametrize("regexp_enabled", ['true', 'false']) +def test_regexp_replace_simple(regexp_enabled): + gen = mk_str_gen('[abcdef]{0,2}') + + conf = { 'spark.rapids.sql.regexp.enabled': regexp_enabled } + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "aa", "PROD")', + 'REGEXP_REPLACE(a, "ab", "PROD")', + 'REGEXP_REPLACE(a, "ae", "PROD")', + 'REGEXP_REPLACE(a, "bc", "PROD")', + 'REGEXP_REPLACE(a, "fa", "PROD")' + ), + conf=conf + ) + +@pytest.mark.parametrize("regexp_enabled", ['true', 'false']) +def test_regexp_replace_multi_optimization(regexp_enabled): + gen = mk_str_gen('[abcdef]{0,2}') + + conf = { 'spark.rapids.sql.regexp.enabled': regexp_enabled } + + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "aa|bb", "PROD")', + 'REGEXP_REPLACE(a, "(aa)|(bb)", "PROD")', + 'REGEXP_REPLACE(a, "aa|bb|cc", "PROD")', + 'REGEXP_REPLACE(a, "(aa)|(bb)|(cc)", "PROD")', + 'REGEXP_REPLACE(a, "aa|bb|cc|dd", "PROD")', + 'REGEXP_REPLACE(a, "(aa|bb)|(cc|dd)", "PROD")', + 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee", "PROD")', + 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")' + ), + conf=conf + ) + + def test_regexp_split_unicode_support(): data_gen = mk_str_gen('([bf]o{0,2}青){1,7}') \ .with_special_case('boo青and青foo') @@ -836,7 +889,7 @@ def test_regexp_memory_fallback(): ), cpu_fallback_class_name='RLike', conf={ - 'spark.rapids.sql.regexp.enabled': 'true', + 'spark.rapids.sql.regexp.enabled': True, 'spark.rapids.sql.regexp.maxStateMemoryBytes': '10', 'spark.rapids.sql.batchSizeBytes': '20' # 1 row in the batch } @@ -858,7 +911,7 @@ def test_regexp_memory_ok(): 'a rlike "1|2|3|4|5|6"' ), conf={ - 'spark.rapids.sql.regexp.enabled': 'true', + 'spark.rapids.sql.regexp.enabled': True, 'spark.rapids.sql.regexp.maxStateMemoryBytes': '12', 'spark.rapids.sql.batchSizeBytes': '20' # 1 row in the batch } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index ce3eca2f630..f858aeee607 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -593,6 +593,11 @@ object GpuOverrides extends Logging { lit.value == null } + def isSupportedStringReplacePattern(strLit: String): Boolean = { + // check for regex special characters, except for \u0000 which we can support + !regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern)) + } + def isSupportedStringReplacePattern(exp: Expression): Boolean = { extractLit(exp) match { case Some(Literal(null, _)) => false @@ -602,7 +607,7 @@ object GpuOverrides extends Logging { false } else { // check for regex special characters, except for \u0000 which we can support - !regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern)) + isSupportedStringReplacePattern(strLit) } case _ => false } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala index c114b870bff..9cde88c5815 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala @@ -16,10 +16,18 @@ package com.nvidia.spark.rapids import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} -import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils, GpuStringReplace} +import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils} import org.apache.spark.sql.types.DataTypes import org.apache.spark.unsafe.types.UTF8String +trait GpuRegExpReplaceOpt extends Serializable + +@SerialVersionUID(100L) +object GpuRegExpStringReplace extends GpuRegExpReplaceOpt + +@SerialVersionUID(100L) +object GpuRegExpStringReplaceMulti extends GpuRegExpReplaceOpt + class GpuRegExpReplaceMeta( expr: RegExpReplace, conf: RapidsConf, @@ -30,11 +38,11 @@ class GpuRegExpReplaceMeta( private var javaPattern: Option[String] = None private var cudfPattern: Option[String] = None private var replacement: Option[String] = None - private var canUseGpuStringReplace = false + private var searchList: Option[Seq[String]] = None + private var replaceOpt: Option[GpuRegExpReplaceOpt] = None private var containsBackref: Boolean = false override def tagExprForGpu(): Unit = { - GpuRegExpUtils.tagForRegExpEnabled(this) replacement = expr.rep match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => Some(s.toString) case _ => None @@ -42,25 +50,32 @@ class GpuRegExpReplaceMeta( expr.regexp match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => - if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { - canUseGpuStringReplace = true - } else { - try { - javaPattern = Some(s.toString()) - val (pat, repl) = - new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None, - replacement) - GpuRegExpUtils.validateRegExpComplexity(this, pat) - cudfPattern = Some(pat.toRegexString) - repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach { - case (hasBackref, convertedRep) => - containsBackref = hasBackref - replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep)) - } - } catch { - case e: RegexUnsupportedException => - willNotWorkOnGpu(e.getMessage) + javaPattern = Some(s.toString()) + try { + val (pat, repl) = + new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None, + replacement) + repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach { + case (hasBackref, convertedRep) => + containsBackref = hasBackref + replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep)) } + if (!containsBackref && GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) { + replaceOpt = Some(GpuRegExpStringReplace) + } else { + searchList = GpuRegExpUtils.getChoicesFromRegex(pat) + searchList match { + case Some(_) if !containsBackref => + replaceOpt = Some(GpuRegExpStringReplaceMulti) + case _ => + GpuRegExpUtils.tagForRegExpEnabled(this) + GpuRegExpUtils.validateRegExpComplexity(this, pat) + cudfPattern = Some(pat.toRegexString) + } + } + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) } case _ => @@ -82,19 +97,27 @@ class GpuRegExpReplaceMeta( // ignore the pos expression which must be a literal 1 after tagging check require(childExprs.length == 4, s"Unexpected child count for RegExpReplace: ${childExprs.length}") - if (canUseGpuStringReplace) { - GpuStringReplace(lhs, regexp, rep) - } else { - (javaPattern, cudfPattern, replacement) match { - case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) => - if (containsBackref) { - GpuRegExpReplaceWithBackref(lhs, regexp, rep)(cudfPattern, cudfReplacement) - } else { - GpuRegExpReplace(lhs, regexp, rep)(javaPattern, cudfPattern, cudfReplacement) - } - case _ => - throw new IllegalStateException("Expression has not been tagged correctly") - } + replaceOpt match { + case None => + (javaPattern, cudfPattern, replacement) match { + case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) => + if (containsBackref) { + GpuRegExpReplaceWithBackref(lhs, regexp, rep)(cudfPattern, cudfReplacement) + } else { + GpuRegExpReplace(lhs, regexp, rep)(javaPattern, cudfPattern, cudfReplacement, + None, None) + } + case _ => + throw new IllegalStateException("Expression has not been tagged correctly") + } + case _ => + (javaPattern, replacement) match { + case (Some(javaPattern), Some(replacement)) => + GpuRegExpReplace(lhs, regexp, rep)(javaPattern, javaPattern, replacement, + searchList, replaceOpt) + case _ => + throw new IllegalStateException("Expression has not been tagged correctly") + } } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index ca25dd1f3d4..d02bb7a0b32 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -742,11 +742,39 @@ case class GpuStringRepeat(input: Expression, repeatTimes: Expression) } +trait HasGpuStringReplace extends Arm { + def doStringReplace( + strExpr: GpuColumnVector, + searchExpr: GpuScalar, + replaceExpr: GpuScalar): ColumnVector = { + // When search or replace string is null, return all nulls like the CPU does. + if (!searchExpr.isValid || !replaceExpr.isValid) { + GpuColumnVector.columnVectorFromNull(strExpr.getRowCount.toInt, StringType) + } else if (searchExpr.getValue.asInstanceOf[UTF8String].numChars() == 0) { + // Return original string if search string is empty + strExpr.getBase.asStrings() + } else { + strExpr.getBase.stringReplace(searchExpr.getBase, replaceExpr.getBase) + } + } + + def doStringReplaceMulti( + strExpr: GpuColumnVector, + search: Seq[String], + replacement: String): ColumnVector = { + withResource(ColumnVector.fromStrings(search: _*)) { targets => + withResource(ColumnVector.fromStrings(replacement)) { repls => + strExpr.getBase.stringReplace(targets, repls) + } + } + } +} + case class GpuStringReplace( srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) - extends GpuTernaryExpression with ImplicitCastInputTypes { + extends GpuTernaryExpression with ImplicitCastInputTypes with HasGpuStringReplace { override def dataType: DataType = srcExpr.dataType @@ -794,15 +822,7 @@ case class GpuStringReplace( strExpr: GpuColumnVector, searchExpr: GpuScalar, replaceExpr: GpuScalar): ColumnVector = { - // When search or replace string is null, return all nulls like the CPU does. - if (!searchExpr.isValid || !replaceExpr.isValid) { - GpuColumnVector.columnVectorFromNull(strExpr.getRowCount.toInt, StringType) - } else if (searchExpr.getValue.asInstanceOf[UTF8String].numChars() == 0) { - // Return original string if search string is empty - strExpr.getBase.asStrings() - } else { - strExpr.getBase.stringReplace(searchExpr.getBase, replaceExpr.getBase) - } + doStringReplace(strExpr, searchExpr, replaceExpr) } override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar, @@ -996,6 +1016,40 @@ object GpuRegExpUtils { countGroups(parseAST(pattern)) } + def getChoicesFromRegex(regex: RegexAST): Option[Seq[String]] = { + regex match { + case RegexGroup(_, t, None) => + getChoicesFromRegex(t) + case RegexChoice(a, b) => + getChoicesFromRegex(a) match { + case Some(la) => + getChoicesFromRegex(b) match { + case Some(lb) => Some(la ++ lb) + case _ => None + } + case _ => None + } + case RegexSequence(parts) => + if (GpuOverrides.isSupportedStringReplacePattern(regex.toRegexString)) { + Some(Seq(regex.toRegexString)) + } else { + parts.foldLeft(Some(Seq[String]()): Option[Seq[String]]) { (m: Option[Seq[String]], r) => + getChoicesFromRegex(r) match { + case Some(l) => m.map(_ ++ l) + case _ => None + } + } + } + case _ => + if (GpuOverrides.isSupportedStringReplacePattern(regex.toRegexString)) { + Some(Seq(regex.toRegexString)) + } else { + None + } + } + } + + } class GpuRLikeMeta( @@ -1114,11 +1168,13 @@ case class GpuRegExpReplace( replaceExpr: Expression) (javaRegexpPattern: String, cudfRegexPattern: String, - cudfReplacementString: String) - extends GpuRegExpTernaryBase with ImplicitCastInputTypes { + cudfReplacementString: String, + searchList: Option[Seq[String]], + replaceOpt: Option[GpuRegExpReplaceOpt]) + extends GpuRegExpTernaryBase with ImplicitCastInputTypes with HasGpuStringReplace { override def otherCopyArgs: Seq[AnyRef] = Seq(javaRegexpPattern, - cudfRegexPattern, cudfReplacementString) + cudfRegexPattern, cudfReplacementString, searchList, replaceOpt) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) override def first: Expression = srcExpr @@ -1129,7 +1185,7 @@ case class GpuRegExpReplace( cudfRegexPattern: String, cudfReplacementString: String) = { this(srcExpr, searchExpr, GpuLiteral("", StringType))(javaRegexpPattern, - cudfRegexPattern, cudfReplacementString) + cudfRegexPattern, cudfReplacementString, None, None) } override def doColumnar( @@ -1139,28 +1195,41 @@ case class GpuRegExpReplace( // For empty strings and a regex containing only a zero-match repetition, // the behavior in some versions of Spark is different. // see https://github.com/NVIDIA/spark-rapids/issues/5456 - val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE) - if (SparkShimImpl.reproduceEmptyStringBug && - GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) { - val isEmpty = withResource(strExpr.getBase.getCharLengths) { len => - withResource(Scalar.fromInt(0)) { zero => - len.equalTo(zero) + replaceOpt match { + case Some(GpuRegExpStringReplace) => + doStringReplace(strExpr, searchExpr, replaceExpr) + case Some(GpuRegExpStringReplaceMulti) => + searchList match { + case Some(searches) => + doStringReplaceMulti(strExpr, searches, cudfReplacementString) + case _ => + throw new IllegalStateException("Need a replace") } - } - withResource(isEmpty) { _ => - withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString => - withResource(GpuScalar.from(cudfReplacementString, DataTypes.StringType)) { rep => - withResource(strExpr.getBase.replaceRegex(prog, rep)) { replacement => - isEmpty.ifElse(emptyString, replacement) + case _ => + val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE) + if (SparkShimImpl.reproduceEmptyStringBug && + GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) { + val isEmpty = withResource(strExpr.getBase.getCharLengths) { len => + withResource(Scalar.fromInt(0)) { zero => + len.equalTo(zero) + } + } + withResource(isEmpty) { _ => + withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString => + withResource(GpuScalar.from(cudfReplacementString, DataTypes.StringType)) { rep => + withResource(strExpr.getBase.replaceRegex(prog, rep)) { replacement => + isEmpty.ifElse(emptyString, replacement) + } + } } } + } else { + withResource(Scalar.fromString(cudfReplacementString)) { rep => + strExpr.getBase.replaceRegex(prog, rep) + } } - } - } else { - withResource(Scalar.fromString(cudfReplacementString)) { rep => - strExpr.getBase.replaceRegex(prog, rep) - } } + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 45c2cd65546..520b0356d38 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -185,6 +185,20 @@ class RegularExpressionParserSuite extends FunSuite { RegexSequence(ListBuffer(RegexChar('a')))), None)))) } + test("multiple choice (2)") { + assert(parse("aa|bb") == RegexChoice(RegexSequence(ListBuffer(RegexChar('a'), RegexChar('a'))), + RegexSequence(ListBuffer(RegexChar('b'), RegexChar('b'))) + )) + } + + test("multiple choice (3)") { + assert(parse("aa|bb|cc") == + RegexChoice(RegexSequence(ListBuffer(RegexChar('a'), RegexChar('a'))), + RegexChoice(RegexSequence(ListBuffer(RegexChar('b'), RegexChar('b'))), + RegexSequence(ListBuffer(RegexChar('c'), RegexChar('c'))) + ))) + } + test("group containing quantifier") { val e = intercept[RegexUnsupportedException] { parse("(?)") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala index 03a099631a9..29ff24fe7a6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala @@ -16,11 +16,12 @@ package com.nvidia.spark.rapids -import org.scalatest.Ignore +import org.scalatest.{FunSuite, Ignore} import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.rapids.GpuRegExpUtils /* * Different versions of Java support different versions of Unicode. @@ -195,6 +196,29 @@ class StringOperatorsSuite extends SparkQueryCompareTestSuite { } } +class RegExpUtilsSuite extends FunSuite { + test("get list of choices from regexp for multi-replace") { + val regexChoices = Map( + "aa|bb" -> Seq("aa", "bb"), + "(aa)|(bb)" -> Seq("aa", "bb"), + "aa|bb|cc" -> Seq("aa", "bb", "cc"), + "(aa)|(bb)|(cc)" -> Seq("aa", "bb", "cc"), + "aa|bb|cc|dd" -> Seq("aa", "bb", "cc", "dd"), + "(aa|bb)|(cc|dd)" -> Seq("aa", "bb", "cc", "dd"), + "aa|bb|cc|dd|ee" -> Seq("aa", "bb", "cc", "dd", "ee"), + "aa|bb|cc|dd|ee|ff" -> Seq("aa", "bb", "cc", "dd", "ee", "ff") + ) + + regexChoices.foreach { case (pattern, choices) => + val (ast, _) = (new CudfRegexTranspiler(RegexReplaceMode)).getTranspiledAST(pattern, + None, Some("")) + val result = GpuRegExpUtils.getChoicesFromRegex(ast) + assert(result.isDefined && result.forall(_ == choices)) + } + + } +} + /* * This isn't actually a test. It's just useful to help visualize what's going on when there are * differences present.