Skip to content

Commit

Permalink
Rewrite regex pattern literal[a-b]{x,y} to custom kernel in rlike (#8)
Browse files Browse the repository at this point in the history
* A hacky approach for regexp rewrite

Signed-off-by: Haoyang Li <[email protected]>

* Use contains instead for that case

Signed-off-by: Haoyang Li <[email protected]>

* add config to switch

Signed-off-by: Haoyang Li <[email protected]>

* Rewrite some rlike expression to StartsWith/EndsWith/Contains

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* Draft code to adapt RegexParser in regex rewrite

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* Apply suggestions from code review

Co-authored-by: Gera Shegalov <[email protected]>

* A checkpoint before removing endsWith rewrite

Signed-off-by: Haoyang Li <[email protected]>

* Remove equalsTo and endsWith

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* address a comment

Signed-off-by: Haoyang Li <[email protected]>

* address a comment

Signed-off-by: Haoyang Li <[email protected]>

* address comments

Signed-off-by: Haoyang Li <[email protected]>

* fix 2.13 build

Signed-off-by: Haoyang Li <[email protected]>

* checkpoint before pattern matching => if

Signed-off-by: Haoyang Li <[email protected]>

* Add prefix range in regex parser rewrite

Signed-off-by: Haoyang Li <[email protected]>

* Address comments

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* change some names

Signed-off-by: Haoyang Li <[email protected]>

* checkpoint before upmerge

Signed-off-by: Haoyang Li <[email protected]>

* add tests

Signed-off-by: Haoyang Li <[email protected]>

* Catch exceptions when trying to examine Iceberg scan for metadata queries (NVIDIA#10836)

Signed-off-by: Jason Lowe <[email protected]>

* Add NVTX ranges to identify Spark stages and tasks (NVIDIA#10826)

* Add NVTX ranges to identify Spark stages and tasks

Signed-off-by: Jason Lowe <[email protected]>

* scalastyle

---------

Signed-off-by: Jason Lowe <[email protected]>

---------

Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Jason Lowe <[email protected]>
Co-authored-by: Gera Shegalov <[email protected]>
Co-authored-by: Jason Lowe <[email protected]>
  • Loading branch information
3 people authored and nvliyuan committed May 28, 2024
1 parent 9b1f897 commit 85bd005
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 84 deletions.
6 changes: 5 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,11 @@ def test_rlike_rewrite_optimization():
'rlike(a, "^^abb")',
'rlike(a, "(.*)(.*)abb")',
'rlike(a, "(.*).*abb.*(.*).*")',
'rlike(a, ".*^abb$")'),
'rlike(a, ".*^abb$")',
'rlike(a, "ab[a-c]\{3\}")',
'rlike(a, "a[a-c]{1,3}")',
'rlike(a, "a[a-c]{1,}")',
'rlike(a, "a[a-c]+")'),
conf=_regexp_conf)

def test_regexp_replace_character_set_negated():
Expand Down
24 changes: 22 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@ import java.lang.reflect.InvocationTargetException
import java.net.URL
import java.time.ZoneId
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.sys.process._
import scala.util.Try

import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner}
import ai.rapids.cudf.{Cuda, CudaException, CudaFatalException, CudfException, MemoryCleaner, NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.RapidsConf.AllowMultipleJars
import com.nvidia.spark.rapids.RapidsPluginUtils.buildInfoEvent
import com.nvidia.spark.rapids.filecache.{FileCache, FileCacheLocalityManager, FileCacheLocalityMsg}
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import org.apache.commons.lang3.exception.ExceptionUtils

import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskFailedReason}
import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.SparkListenerEvent
Expand Down Expand Up @@ -495,6 +496,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
var rapidsShuffleHeartbeatEndpoint: RapidsShuffleHeartbeatEndpoint = null
private lazy val extraExecutorPlugins =
RapidsPluginUtils.extraPlugins.map(_.executorPlugin()).filterNot(_ == null)
private val activeTaskNvtx = new ConcurrentHashMap[Thread, NvtxRange]()

override def init(
pluginContext: PluginContext,
Expand Down Expand Up @@ -687,15 +689,33 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
logDebug(s"Executor onTaskFailed: ${other.toString}")
}
extraExecutorPlugins.foreach(_.onTaskFailed(failureReason))
endTaskNvtx()
}

/**
 * Task-start hook: opens the NVTX range for the current task first, then forwards the
 * notification to every extra executor plugin and finally to the executor-side profiler.
 */
override def onTaskStart(): Unit = {
  val taskCtx = TaskContext.get
  startTaskNvtx(taskCtx)
  for (plugin <- extraExecutorPlugins) {
    plugin.onTaskStart()
  }
  ProfilerOnExecutor.onTaskStart()
}

/**
 * Task-success hook: forwards the notification to every extra executor plugin and then
 * closes the NVTX range that was opened for this task on the current thread.
 *
 * The range is closed in a `finally` so that an exception thrown by one of the extra
 * plugins cannot leave the range open and its entry stuck in `activeTaskNvtx`; the
 * exception itself still propagates to the caller.
 */
override def onTaskSucceeded(): Unit = {
  try {
    extraExecutorPlugins.foreach(_.onTaskSucceeded())
  } finally {
    endTaskNvtx()
  }
}

/**
 * Opens an NVTX range labelled with the stage id, task attempt id and attempt number of
 * the given task, and records it in `activeTaskNvtx` under the current thread so that
 * `endTaskNvtx` can find and close it later.
 *
 * @param taskCtx context of the task that is starting on the current thread
 */
private def startTaskNvtx(taskCtx: TaskContext): Unit = {
  val rangeName =
    s"Stage ${taskCtx.stageId()} Task ${taskCtx.taskAttemptId()}-${taskCtx.attemptNumber()}"
  val owner = Thread.currentThread()
  activeTaskNvtx.put(owner, new NvtxRange(rangeName, NvtxColor.DARK_GREEN))
}

/**
 * Removes the NVTX range registered for the current thread, if there is one, and closes
 * it. Does nothing when no range is registered for this thread.
 */
private def endTaskNvtx(): Unit = {
  // ConcurrentHashMap.remove returns null when the key is absent; Option(...) absorbs that.
  Option(activeTaskNvtx.remove(Thread.currentThread())).foreach(_.close())
}
}

Expand Down
105 changes: 75 additions & 30 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,20 @@ class CudfRegexTranspiler(mode: RegexMode) {
(cudfRegex.toRegexString, replacement.map(_.toRegexString))
}

/**
 * Transpile an already-parsed regular expression, returning the result in AST form.
 *
 * @param regex parsed source regular expression
 * @param extractIndex passed through unchanged to `transpile`
 * @param repl optional replacement string; when present it is parsed with the regex parser
 *             so that back-references are accounted for
 * @return the transpiled AST together with the parsed replacement, if any
 */
def getTranspiledAST(
    regex: RegexAST,
    extractIndex: Option[Int],
    repl: Option[String]): (RegexAST, Option[RegexReplacement]) = {
  // the replacement parser needs to know how many capture groups the pattern defines
  val parsedReplacement = repl.map { replStr =>
    val replParser = new RegexParser(replStr)
    replParser.parseReplacement(countCaptureGroups(regex))
  }

  // transpiling also validates that the regex is supported by cuDF
  (transpile(regex, extractIndex, parsedReplacement, None), parsedReplacement)
}

/**
* Parse Java regular expression and translate into cuDF regular expression in AST form.
*
Expand All @@ -734,14 +748,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
repl: Option[String]): (RegexAST, Option[RegexReplacement]) = {
// parse the source regular expression
val regex = new RegexParser(pattern).parse()
// if we have a replacement, parse the replacement string using the regex parser to account
// for backrefs
val replacement = repl.map(s => new RegexParser(s).parseReplacement(countCaptureGroups(regex)))

// validate that the regex is supported by cuDF
val cudfRegex = transpile(regex, extractIndex, replacement, None)

(cudfRegex, replacement)
getTranspiledAST(regex, extractIndex, repl)
}

def transpileToSplittableString(e: RegexAST): Option[String] = {
Expand Down Expand Up @@ -2014,14 +2021,56 @@ sealed trait RegexOptimizationType
/**
 * The possible outcomes of trying to rewrite an `rlike` pattern into a cheaper operation.
 * `RegexRewrite.matchSimplePattern` produces one of these and `GpuRLikeMeta.convertToGpu`
 * picks the GPU expression based on it.
 */
object RegexOptimizationType {
// `^literal.*` style pattern: converted to GpuStartsWith on `literal`
case class StartsWith(literal: String) extends RegexOptimizationType
// `literal.*` / `(literal).*` style pattern: converted to GpuContains on `literal`
case class Contains(literal: String) extends RegexOptimizationType
// `literal[a-b]{x,y}` style pattern: converted to GpuLiteralRangePattern.
// `length` is the minimum repetition count taken from the quantifier and
// `rangeStart`/`rangeEnd` are the range's bounding characters converted to Int.
case class PrefixRange(literal: String, length: Int, rangeStart: Int, rangeEnd: Int)
extends RegexOptimizationType
// no rewrite applies: the regular GpuRLike expression is used
case object NoOptimization extends RegexOptimizationType
}

object RegexRewriteUtils {
object RegexRewrite {

/**
 * Strips enclosing groups: while the sequence consists of exactly one `RegexGroup` whose
 * third field is `None`, it is replaced by the children of that group's term. Any other
 * sequence is returned unchanged.
 */
private def removeBrackets(astLs: collection.Seq[RegexAST]): collection.Seq[RegexAST] = {
  var unwrapped: collection.Seq[RegexAST] = astLs
  var done = false
  while (!done) {
    unwrapped match {
      case collection.Seq(RegexGroup(_, inner, None)) => unwrapped = inner.children()
      case _ => done = true
    }
  }
  unwrapped
}
/**
 * Tries to recognise a pattern of the shape `literal[a-b]{x,y}`: a literal prefix followed
 * by a repetition of a character class that holds exactly one character range.
 *
 * @param astLs the pattern's AST nodes; may be empty
 * @return `Some((prefix, length, start, end))` when the pattern has that shape, where
 *         `prefix` is the literal prefix (possibly empty), `length` is the minimum number
 *         of repetitions implied by the quantifier (0 for `*` and `?`, 1 for `+`) and
 *         `start`/`end` are the range's bounding characters converted to Int;
 *         `None` otherwise, including for an empty `astLs`
 */
private def getPrefixRangePattern(astLs: collection.Seq[RegexAST]):
    Option[(String, Int, Int, Int)] = {
  // `lastOption` rather than `last`: the sequence is empty when the pattern contained
  // nothing but wildcards that the caller already stripped, and `last` would throw.
  val endsWithRange = astLs.lastOption.flatMap {
    case RegexRepetition(
        RegexCharacterClass(false, ListBuffer(RegexCharacterRange(a, b))),
        quantifier) =>
      // both ends of the range must be plain characters
      val bounds = (a, b) match {
        case (RegexChar(start), RegexChar(end)) => Some((start, end))
        case _ => None
      }
      // only the lower bound of the quantifier is used
      val minLength = quantifier match {
        case QuantifierVariableLength(start, _) => Some(start)
        case QuantifierFixedLength(len) => Some(len)
        case SimpleQuantifier(ch) => ch match {
          case '*' | '?' => Some(0)
          case '+' => Some(1)
          case _ => None
        }
        case _ => None
      }
      bounds.flatMap { case (start, end) =>
        minLength.map(length => (length, start.toInt, end.toInt))
      }
    case _ => None
  }
  // everything before the trailing repetition has to be a plain literal
  val prefixAst = astLs.dropRight(1)
  endsWithRange.collect {
    case (length, start, end) if isliteralString(prefixAst) =>
      (RegexCharsToString(prefixAst), length, start, end)
  }
}

private def isliteralString(astLs: collection.Seq[RegexAST]): Boolean = {
astLs.forall {
case RegexChar('^') | RegexChar('$') | RegexChar('.') => false
case RegexChar(_) => true
removeBrackets(astLs).forall {
case RegexChar(ch) if !regexMetaChars.contains(ch) => true
case _ => false
}
}
Expand All @@ -2048,7 +2097,7 @@ object RegexRewriteUtils {
}

private def RegexCharsToString(chars: collection.Seq[RegexAST]): String = {
chars.map {
removeBrackets(chars).map {
case RegexChar(ch) => ch
case _ => throw new IllegalArgumentException("Invalid character")
}.mkString
Expand All @@ -2063,30 +2112,26 @@ object RegexRewriteUtils {
*/
def matchSimplePattern(ast: RegexAST): RegexOptimizationType = {
ast.children() match {
case (RegexChar('^') | RegexEscaped('A')) :: RegexGroup(_, RegexSequence(parts), None) :: rest
if isliteralString(parts) && rest.forall(isWildcard) => {
// ^(literal).* => startsWith literal
RegexOptimizationType.StartsWith(RegexCharsToString(parts))
}
case (RegexChar('^') | RegexEscaped('A')) :: ast
if isliteralString(stripTailingWildcards(ast)) => {
// ^literal.* => startsWith literal
println("Starts with optimization")
RegexOptimizationType.StartsWith(RegexCharsToString(stripTailingWildcards(ast)))
}
case noStartsWithAst => stripLeadingWildcards(noStartsWithAst) match {
case RegexGroup(_, RegexSequence(parts), None) :: rest
if isliteralString(parts) && rest.forall(isWildcard) => {
// (literal).* => contains literal
RegexOptimizationType.Contains(RegexCharsToString(parts))
}
case ast if isliteralString(stripTailingWildcards(ast)) => {
// literal.* => contains literal
RegexOptimizationType.Contains(RegexCharsToString(stripTailingWildcards(ast)))
}
case _ => {
case astLs => {
val noStartsWithAst = stripTailingWildcards(stripLeadingWildcards(astLs))
val prefixRangeInfo = getPrefixRangePattern(noStartsWithAst)
if (prefixRangeInfo.isDefined) {
val (prefix, length, start, end) = prefixRangeInfo.get
// (literal[a-b]{x,y}) => prefix range pattern
RegexOptimizationType.PrefixRange(prefix, length, start, end)
} else if (isliteralString(noStartsWithAst)) {
// literal.* or (literal).* => contains literal
RegexOptimizationType.Contains(RegexCharsToString(noStartsWithAst))
} else {
RegexOptimizationType.NoOptimization
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.nvidia.spark.rapids.iceberg

import scala.reflect.ClassTag
import scala.util.{Failure, Try}
import scala.util.{Failure, Success, Try}

import com.nvidia.spark.rapids.{FileFormatChecks, GpuScan, IcebergFormatType, RapidsConf, ReadFileOp, ScanMeta, ScanRule, ShimReflectionUtils}
import com.nvidia.spark.rapids.iceberg.spark.source.GpuSparkBatchQueryScan
Expand Down Expand Up @@ -48,8 +48,12 @@ class IcebergProviderImpl extends IcebergProvider {

FileFormatChecks.tag(this, a.readSchema(), IcebergFormatType, ReadFileOp)

if (GpuSparkBatchQueryScan.isMetadataScan(a)) {
willNotWorkOnGpu("scan is a metadata scan")
Try {
GpuSparkBatchQueryScan.isMetadataScan(a)
} match {
case Success(true) => willNotWorkOnGpu("scan is a metadata scan")
case Failure(e) => willNotWorkOnGpu(s"error examining CPU Iceberg scan: $e")
case _ =>
}

convertedScan match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, Co
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.{CastStrings, StringDigitsPattern}
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -1066,38 +1067,27 @@ object RegexprPart {
}

class GpuRLikeMeta(
expr: RLike,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule) extends BinaryExprMeta[RLike](expr, conf, parent, rule) {

private var originalPattern: String = ""
expr: RLike,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule) extends BinaryExprMeta[RLike](expr, conf, parent, rule) {
import RegexOptimizationType._
private var pattern: Option[String] = None

def optimizeSimplePattern(lhs: Expression, rhs: Expression): GpuExpression = {
import RegexOptimizationType._
val originalAst = new RegexParser(originalPattern).parse()
RegexRewriteUtils.matchSimplePattern(originalAst) match {
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(s, StringType))
case NoOptimization => {
val patternStr = pattern.getOrElse(throw new IllegalStateException(
"Expression has not been tagged with cuDF regex pattern"))
GpuRLike(lhs, rhs, patternStr)
}
case _ => throw new IllegalStateException("Unexpected optimization type")
}
}
private var rewriteOptimizationType: RegexOptimizationType = NoOptimization

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.right match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
// verify that we support this regex and can transpile it to cuDF format
originalPattern = str.toString
val originalPattern = str.toString
val regexAst = new RegexParser(originalPattern).parse()
if (conf.isRlikeRegexRewriteEnabled) {
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst)
}
val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode)
.getTranspiledAST(originalPattern, None, None)
.getTranspiledAST(regexAst, None, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
pattern = Some(transpiledAST.toRegexString)
} catch {
Expand All @@ -1110,56 +1100,60 @@ class GpuRLikeMeta(
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
if (conf.isRlikeRegexRewriteEnabled) {
// if the pattern can be converted to a startswith or endswith pattern, we can use
// GpuStartsWith, GpuEndsWith or GpuContains instead to get better performance
optimizeSimplePattern(lhs, rhs)
} else {
val patternStr = pattern.getOrElse(throw new IllegalStateException(
rewriteOptimizationType match {
case NoOptimization => {
val patternStr = pattern.getOrElse(throw new IllegalStateException(
"Expression has not been tagged with cuDF regex pattern"))
GpuRLike(lhs, rhs, patternStr)
GpuRLike(lhs, rhs, patternStr)
}
case StartsWith(s) => GpuStartsWith(lhs, GpuLiteral(s, StringType))
case Contains(s) => GpuContains(lhs, GpuLiteral(s, StringType))
case PrefixRange(s, length, start, end) =>
GpuLiteralRangePattern(lhs, GpuLiteral(s, StringType), length, start, end)
case _ => throw new IllegalStateException("Unexpected optimization type")
}
}
}

case class GpuStringDigits(left: Expression, right: Expression, from: Int, start: Int, end: Int)
extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes with NullIntolerant {

override def dataType: DataType = BooleanType
case class GpuRLike(left: Expression, right: Expression, pattern: String)
extends GpuBinaryExpressionArgsAnyScalar
with ImplicitCastInputTypes
with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
override def toString: String = s"$left gpurlike $right"

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
StringDigitsPattern.stringDigitsPattern(lhs.getBase, rhs.getBase, from, start, end)
lhs.getBase.containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE))
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs =>
doColumnar(expandedLhs, rhs)
}
}

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

override def dataType: DataType = BooleanType
}

case class GpuRLike(left: Expression, right: Expression, pattern: String)
extends GpuBinaryExpressionArgsAnyScalar
with ImplicitCastInputTypes
with NullIntolerant {
case class GpuLiteralRangePattern(left: Expression, right: Expression,
from: Int, start: Int, end: Int)
extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes with NullIntolerant {

override def toString: String = s"$left gpurlike $right"
override def dataType: DataType = BooleanType

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
lhs.getBase.containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE))
RegexRewriteUtils.literalRangePattern(lhs.getBase, rhs.getBase, from, start, end)
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs =>
doColumnar(expandedLhs, rhs)
}
}

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

override def dataType: DataType = BooleanType
}

abstract class GpuRegExpTernaryBase
Expand Down
Loading

0 comments on commit 85bd005

Please sign in to comment.