diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 45d5e07dd73..0f5ada9f7fa 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids import java.sql.SQLException -import scala.collection import scala.collection.mutable.ListBuffer import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars @@ -73,7 +72,7 @@ class RegexParser(pattern: String) { sequence } - def parseReplacementBase(): RegexAST = { + private def parseReplacementBase(): RegexAST = { consume() match { case '\\' => parseBackrefOrEscaped() @@ -782,6 +781,7 @@ class CudfRegexTranspiler(mode: RegexMode) { } } + @scala.annotation.tailrec private def isRepetition(e: RegexAST, checkZeroLength: Boolean): Boolean = { e match { case RegexRepetition(_, _) if !checkZeroLength => true @@ -1648,6 +1648,7 @@ class CudfRegexTranspiler(mode: RegexMode) { } } + @scala.annotation.tailrec private def isEntirely(regex: RegexAST, f: RegexAST => Boolean): Boolean = { regex match { case RegexSequence(parts) if parts.nonEmpty => @@ -1672,6 +1673,7 @@ class CudfRegexTranspiler(mode: RegexMode) { }) } + @scala.annotation.tailrec private def beginsWith(regex: RegexAST, f: RegexAST => Boolean): Boolean = { regex match { case RegexSequence(parts) if parts.nonEmpty => @@ -1687,6 +1689,7 @@ class CudfRegexTranspiler(mode: RegexMode) { } + @scala.annotation.tailrec private def endsWith(regex: RegexAST, f: RegexAST => Boolean): Boolean = { regex match { case RegexSequence(parts) if parts.nonEmpty => @@ -1760,7 +1763,7 @@ sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { } sealed case class RegexGroup(capture: Boolean, term: RegexAST, - val lookahead: Option[RegexLookahead]) + lookahead: Option[RegexLookahead]) extends RegexAST { def this(capture: Boolean, term: RegexAST) = { this(capture, term, None) @@ -2028,6 +2031,7 @@ object RegexOptimizationType { object RegexRewrite { + @scala.annotation.tailrec private def removeBrackets(astLs: collection.Seq[RegexAST]): collection.Seq[RegexAST] = { astLs match { case collection.Seq(RegexGroup(_, term, None)) => removeBrackets(term.children()) @@ -2044,7 +2048,7 @@ object RegexRewrite { */ private def getPrefixRangePattern(astLs: collection.Seq[RegexAST]): Option[(String, Int, Int, Int)] = { - val haveLiteralPrefix = isliteralString(astLs.dropRight(1)) + val haveLiteralPrefix = isLiteralString(astLs.dropRight(1)) val endsWithRange = astLs.lastOption match { case Some(RegexRepetition( RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a,b))), @@ -2080,9 +2084,9 @@ object RegexRewrite { } } - private def isliteralString(astLs: collection.Seq[RegexAST]): Boolean = { + private def isLiteralString(astLs: collection.Seq[RegexAST]): Boolean = { removeBrackets(astLs).forall { - case RegexChar(ch) if !regexMetaChars.contains(ch) => true + case RegexChar(ch) => !regexMetaChars.contains(ch) case _ => false } } @@ -2120,16 +2124,26 @@ object RegexRewrite { * Matches the given regex ast to a regex optimization type for regex rewrite * optimization. * - * @param ast The Abstract Syntax Tree parsed from a regex pattern. + * @param ast unparsed children of the Abstract Syntax Tree parsed from a regex pattern. * @return The `RegexOptimizationType` for the given pattern. */ - def matchSimplePattern(ast: RegexAST): RegexOptimizationType = { - ast.children() match { - case (RegexChar('^') | RegexEscaped('A')) :: ast - if isliteralString(stripTailingWildcards(ast)) => { - // ^literal.* => startsWith literal - RegexOptimizationType.StartsWith(RegexCharsToString(stripTailingWildcards(ast))) - } + @scala.annotation.tailrec + def matchSimplePattern(ast: Seq[RegexAST]): RegexOptimizationType = { + ast match { + case (RegexChar('^') | RegexEscaped('A')) :: astTail => + val noTrailingWildCards = stripTailingWildcards(astTail) + if (isLiteralString(noTrailingWildCards)) { + // ^literal.* => startsWith literal + RegexOptimizationType.StartsWith(RegexCharsToString(noTrailingWildCards)) + } else { + val noWildCards = stripLeadingWildcards(noTrailingWildCards) + if (noWildCards.length == noTrailingWildCards.length) { + // TODO startsWith with PrefIxRange + RegexOptimizationType.NoOptimization + } else { + matchSimplePattern(astTail) + } + } case astLs => { val noStartsWithAst = stripTailingWildcards(stripLeadingWildcards(astLs)) val prefixRangeInfo = getPrefixRangePattern(noStartsWithAst) @@ -2137,7 +2151,7 @@ object RegexRewrite { val (prefix, length, start, end) = prefixRangeInfo.get // (literal[a-b]{x,y}) => prefix range pattern RegexOptimizationType.PrefixRange(prefix, length, start, end) - } else if (isliteralString(noStartsWithAst)) { + } else if (isLiteralString(noStartsWithAst)) { // literal.* or (literal).* => contains literal RegexOptimizationType.Contains(RegexCharsToString(noStartsWithAst)) } else { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index b875c84edbf..8fea4014149 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1073,7 +1073,7 @@ class GpuRLikeMeta( val originalPattern = str.toString val regexAst = new RegexParser(originalPattern).parse() if (conf.isRlikeRegexRewriteEnabled) { - rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst) + rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst.children()) } val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode) .getTranspiledAST(regexAst, None, None) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala index a9ef6364aac..a140f4123f4 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionRewriteSuite.scala @@ -23,7 +23,7 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { Unit = { val results = patterns.map { pattern => val ast = new RegexParser(pattern).parse() - RegexRewrite.matchSimplePattern(ast) + RegexRewrite.matchSimplePattern(ast.children()) } assert(results == excepted) } @@ -53,12 +53,23 @@ class RegularExpressionRewriteSuite extends AnyFunSuite { "(.*)abc[0-9a-z]{1,3}(.*)", "(.*)abc[0-9]{2}.*", "^abc[0-9]{1,3}", - "火花急流[\u4e00-\u9fa5]{1}") - val excepted = Seq(PrefixRange("abc", 1, 48, 57), - NoOptimization, - PrefixRange("abc", 2, 48, 57), + "火花急流[\u4e00-\u9fa5]{1}", + "^[0-9]{6}", + "^[0-9]{3,10}", + "^.*[0-9]{6}", + "^(.*)[0-9]{3,10}" + ) + val excepted = Seq( PrefixRange("abc", 1, 48, 57), - PrefixRange("火花急流", 1, 19968, 40869)) + NoOptimization, // prefix followed by a multi-range not supported + PrefixRange("abc", 2, 48, 57), + NoOptimization, // starts with PrefixRange not supported + PrefixRange("火花急流", 1, 19968, 40869), + NoOptimization, // starts with PrefixRange not supported + NoOptimization, // starts with PrefixRange not supported + PrefixRange("", 6, 48, 57), + PrefixRange("", 3, 48, 57) + ) verifyRewritePattern(patterns, excepted) } }