Skip to content

Commit

Permalink
Merge pull request #10954 from NVIDIA/branch-24.06
Browse files Browse the repository at this point in the history
[auto-merge] branch-24.06 to branch-24.08 [skip ci] [bot]
  • Loading branch information
nvauto authored May 31, 2024
2 parents 2a86bb5 + 022fdd1 commit bbdcac0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
44 changes: 29 additions & 15 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package com.nvidia.spark.rapids

import java.sql.SQLException

import scala.collection
import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.GpuOverrides.regexMetaChars
Expand Down Expand Up @@ -73,7 +72,7 @@ class RegexParser(pattern: String) {
sequence
}

def parseReplacementBase(): RegexAST = {
private def parseReplacementBase(): RegexAST = {
consume() match {
case '\\' =>
parseBackrefOrEscaped()
Expand Down Expand Up @@ -782,6 +781,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

@scala.annotation.tailrec
private def isRepetition(e: RegexAST, checkZeroLength: Boolean): Boolean = {
e match {
case RegexRepetition(_, _) if !checkZeroLength => true
Expand Down Expand Up @@ -1648,6 +1648,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}

@scala.annotation.tailrec
private def isEntirely(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
regex match {
case RegexSequence(parts) if parts.nonEmpty =>
Expand All @@ -1672,6 +1673,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
})
}

@scala.annotation.tailrec
private def beginsWith(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
regex match {
case RegexSequence(parts) if parts.nonEmpty =>
Expand All @@ -1687,6 +1689,7 @@ class CudfRegexTranspiler(mode: RegexMode) {

}

@scala.annotation.tailrec
private def endsWith(regex: RegexAST, f: RegexAST => Boolean): Boolean = {
regex match {
case RegexSequence(parts) if parts.nonEmpty =>
Expand Down Expand Up @@ -1760,7 +1763,7 @@ sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST {
}

sealed case class RegexGroup(capture: Boolean, term: RegexAST,
val lookahead: Option[RegexLookahead])
lookahead: Option[RegexLookahead])
extends RegexAST {
def this(capture: Boolean, term: RegexAST) = {
this(capture, term, None)
Expand Down Expand Up @@ -2028,6 +2031,7 @@ object RegexOptimizationType {

object RegexRewrite {

@scala.annotation.tailrec
private def removeBrackets(astLs: collection.Seq[RegexAST]): collection.Seq[RegexAST] = {
astLs match {
case collection.Seq(RegexGroup(_, term, None)) => removeBrackets(term.children())
Expand All @@ -2044,7 +2048,7 @@ object RegexRewrite {
*/
private def getPrefixRangePattern(astLs: collection.Seq[RegexAST]):
Option[(String, Int, Int, Int)] = {
val haveLiteralPrefix = isliteralString(astLs.dropRight(1))
val haveLiteralPrefix = isLiteralString(astLs.dropRight(1))
val endsWithRange = astLs.lastOption match {
case Some(RegexRepetition(
RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a,b))),
Expand Down Expand Up @@ -2080,9 +2084,9 @@ object RegexRewrite {
}
}

private def isliteralString(astLs: collection.Seq[RegexAST]): Boolean = {
private def isLiteralString(astLs: collection.Seq[RegexAST]): Boolean = {
removeBrackets(astLs).forall {
case RegexChar(ch) if !regexMetaChars.contains(ch) => true
case RegexChar(ch) => !regexMetaChars.contains(ch)
case _ => false
}
}
Expand Down Expand Up @@ -2120,24 +2124,34 @@ object RegexRewrite {
* Matches the given regex ast to a regex optimization type for regex rewrite
* optimization.
*
* @param ast The Abstract Syntax Tree parsed from a regex pattern.
* @param ast unparsed children of the Abstract Syntax Tree parsed from a regex pattern.
* @return The `RegexOptimizationType` for the given pattern.
*/
def matchSimplePattern(ast: RegexAST): RegexOptimizationType = {
ast.children() match {
case (RegexChar('^') | RegexEscaped('A')) :: ast
if isliteralString(stripTailingWildcards(ast)) => {
// ^literal.* => startsWith literal
RegexOptimizationType.StartsWith(RegexCharsToString(stripTailingWildcards(ast)))
}
@scala.annotation.tailrec
def matchSimplePattern(ast: Seq[RegexAST]): RegexOptimizationType = {
ast match {
case (RegexChar('^') | RegexEscaped('A')) :: astTail =>
val noTrailingWildCards = stripTailingWildcards(astTail)
if (isLiteralString(noTrailingWildCards)) {
// ^literal.* => startsWith literal
RegexOptimizationType.StartsWith(RegexCharsToString(noTrailingWildCards))
} else {
val noWildCards = stripLeadingWildcards(noTrailingWildCards)
if (noWildCards.length == noTrailingWildCards.length) {
// TODO startsWith with PrefIxRange
RegexOptimizationType.NoOptimization
} else {
matchSimplePattern(astTail)
}
}
case astLs => {
val noStartsWithAst = stripTailingWildcards(stripLeadingWildcards(astLs))
val prefixRangeInfo = getPrefixRangePattern(noStartsWithAst)
if (prefixRangeInfo.isDefined) {
val (prefix, length, start, end) = prefixRangeInfo.get
// (literal[a-b]{x,y}) => prefix range pattern
RegexOptimizationType.PrefixRange(prefix, length, start, end)
} else if (isliteralString(noStartsWithAst)) {
} else if (isLiteralString(noStartsWithAst)) {
// literal.* or (literal).* => contains literal
RegexOptimizationType.Contains(RegexCharsToString(noStartsWithAst))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ class GpuRLikeMeta(
val originalPattern = str.toString
val regexAst = new RegexParser(originalPattern).parse()
if (conf.isRlikeRegexRewriteEnabled) {
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst)
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst.children())
}
val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode)
.getTranspiledAST(regexAst, None, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
Unit = {
val results = patterns.map { pattern =>
val ast = new RegexParser(pattern).parse()
RegexRewrite.matchSimplePattern(ast)
RegexRewrite.matchSimplePattern(ast.children())
}
assert(results == excepted)
}
Expand Down Expand Up @@ -53,12 +53,23 @@ class RegularExpressionRewriteSuite extends AnyFunSuite {
"(.*)abc[0-9a-z]{1,3}(.*)",
"(.*)abc[0-9]{2}.*",
"^abc[0-9]{1,3}",
"火花急流[\u4e00-\u9fa5]{1}")
val excepted = Seq(PrefixRange("abc", 1, 48, 57),
NoOptimization,
PrefixRange("abc", 2, 48, 57),
"火花急流[\u4e00-\u9fa5]{1}",
"^[0-9]{6}",
"^[0-9]{3,10}",
"^.*[0-9]{6}",
"^(.*)[0-9]{3,10}"
)
val excepted = Seq(
PrefixRange("abc", 1, 48, 57),
PrefixRange("火花急流", 1, 19968, 40869))
NoOptimization, // prefix followed by a multi-range not supported
PrefixRange("abc", 2, 48, 57),
NoOptimization, // starts with PrefixRange not supported
PrefixRange("火花急流", 1, 19968, 40869),
NoOptimization, // starts with PrefixRange not supported
NoOptimization, // starts with PrefixRange not supported
PrefixRange("", 6, 48, 57),
PrefixRange("", 3, 48, 57)
)
verifyRewritePattern(patterns, excepted)
}
}

0 comments on commit bbdcac0

Please sign in to comment.