Skip to content

Commit

Permalink
Use [\u4e00-\u9fa5]+ regex rewrite kernel (#7)
Browse files Browse the repository at this point in the history
* A hacky approach for regexpr rewrite

Signed-off-by: Haoyang Li <[email protected]>

* Use contains instead for that case

Signed-off-by: Haoyang Li <[email protected]>

* add config to switch

Signed-off-by: Haoyang Li <[email protected]>

* Rewrite some rlike expression to StartsWith/EndsWith/Contains

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* add tests and config

Signed-off-by: Haoyang Li <[email protected]>

* support range filter

Signed-off-by: Haoyang Li <[email protected]>

---------

Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven authored May 6, 2024
1 parent 50822d5 commit cfc27b5
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 8 deletions.
44 changes: 43 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
else:
pytestmark = pytest.mark.regexp

_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True }
# Spark conf passed to the regexp tests in this module: enables GPU regexp support
# and sets the rlike rewrite mode to 'new'.
_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True ,
'spark.rapids.sql.rLikeRegexRewrite.enabled': 'new'}

def mk_str_gen(pattern):
    """Return a StringGen for ``pattern`` with '' registered as a special case
    and '.{0,10}' registered as a special pattern."""
    base_gen = StringGen(pattern)
    with_empty = base_gen.with_special_case('')
    return with_empty.with_special_pattern('.{0,10}')
Expand Down Expand Up @@ -444,6 +445,47 @@ def test_regexp_like():
'regexp_like(a, "a[bc]d")'),
conf=_regexp_conf)

@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0')
def test_regexp_rlike_rewrite_optimization():
    """CPU and GPU must agree on regexp_like for patterns built from anchors,
    `.*` wildcards and the literal 'abcd'."""
    # One projected column per pattern, in the same order as before.
    rlike_exprs = [
        'regexp_like(a, "(abcd)(.*)")',
        'regexp_like(a, "abcd(.*)")',
        'regexp_like(a, "(.*)(abcd)(.*)")',
        'regexp_like(a, "^(abcd)(.*)")',
        'regexp_like(a, "^abcd")',
        'regexp_like(a, "(abcd)$")',
        'regexp_like(a, ".*abcd$")',
        'regexp_like(a, "^(abcd)$")',
        'regexp_like(a, "^abcd$")',
        'regexp_like(a, "ab(.*)cd")',
        'regexp_like(a, "^^abcd")',
        'regexp_like(a, "(.*)(.*)abcd")',
    ]
    gen = mk_str_gen('[abcd]{3,6}')

    def project(spark):
        return unary_op_df(spark, gen).selectExpr('a', *rlike_exprs)

    assert_gpu_and_cpu_are_equal_collect(project, conf=_regexp_conf)

@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0')
def test_regexp_rlike_rewrite_optimization_str_dig():
    """CPU and GPU must agree on regexp_like for the digit-run patterns, with
    and without a literal prefix."""
    rlike_exprs = [
        'regexp_like(a, "[0-9]{4,}")',
        'regexp_like(a, "abcd([0-9]{5})")',
    ]
    gen = mk_str_gen('([abcd]{3,6})?[0-9]{2,5}')

    def project(spark):
        return unary_op_df(spark, gen).selectExpr('a', *rlike_exprs)

    assert_gpu_and_cpu_are_equal_collect(project, conf=_regexp_conf)

# The test below covers the Chinese-character class pattern [\\u4e00-\\u9fa5]+

@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0')
def test_regexp_rlike_rewrite_optimization_chinese():
    """CPU and GPU must agree on regexp_like for the Chinese-character class
    pattern over strings mixing digits, Chinese characters and lowercase letters."""
    rlike_exprs = [
        'regexp_like(a, "[\\u4e00-\\u9fa5]+")',
    ]
    gen = mk_str_gen('[0-9]{0,2}([英伟达]{0,3})?[a-z]{0,2}')

    def project(spark):
        return unary_op_df(spark, gen).selectExpr('a', *rlike_exprs)

    assert_gpu_and_cpu_are_equal_collect(project, conf=_regexp_conf)

def test_regexp_replace_character_set_negated():
gen = mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,12 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.booleanConf
.createWithDefault(true)

// Despite the ".enabled" suffix this is a string-valued mode selector, not a boolean:
// GpuRLikeMeta.convertToGpu compares the value against "new" and "legacy" and treats
// every other value as "no rewrite".
val ENABLE_RLIKE_REGEX_REWRITE = conf("spark.rapids.sql.rLikeRegexRewrite.enabled")
  .doc("Selects how simple rlike patterns are rewritten to cheaper string operations. " +
    "'new' enables all rewrites, including the digit/character-range kernel; " +
    "'legacy' uses the older rewrite rules; any other value disables the rewrite " +
    "and always evaluates the regular expression.")
  .internal()
  .stringConf
  .createWithDefault("new")

val ENABLE_GETJSONOBJECT_LEGACY = conf("spark.rapids.sql.getJsonObject.legacy.enabled")
.doc("When set to true, the get_json_object function will use the legacy implementation " +
"on the GPU. The legacy implementation is faster than the current implementation, but " +
Expand Down Expand Up @@ -2624,6 +2630,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isTieredProjectEnabled: Boolean = get(ENABLE_TIERED_PROJECT)

// Raw string value of ENABLE_RLIKE_REGEX_REWRITE. Note this is a mode name
// (e.g. "new" or "legacy"), not a boolean, despite the "is...Enabled" name.
lazy val isRlikeRegexRewriteEnabled: String = get(ENABLE_RLIKE_REGEX_REWRITE)

lazy val isLegacyGetJsonObjectEnabled: Boolean = get(ENABLE_GETJSONOBJECT_LEGACY)

lazy val isExpandPreprojectEnabled: Boolean = get(ENABLE_EXPAND_PREPROJECT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, Co
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.{CastStrings, StringDigitsPattern}
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -168,7 +168,7 @@ case class GpuStartsWith(left: Expression, right: Expression)

// Renders the expression as gpustartswith(<left>, <right>).
override def toString: String = s"gpustartswith($left, $right)"

// Column-vs-scalar case: delegates to `startsWith` on the base column, passing the
// scalar's base value as the prefix.
def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
  lhs.getBase.startsWith(rhs.getBase)

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
Expand Down Expand Up @@ -1054,22 +1054,144 @@ object GpuRegExpUtils {

}

/**
 * Coarse token of a regular expression, as produced by
 * `GpuRLikeMeta.parseRegexToParts` and consumed by the rlike rewrite rules.
 */
sealed trait RegexprPart

object RegexprPart {
  /** The `^` anchor. */
  case object Start extends RegexprPart

  /** The `$` anchor. */
  case object End extends RegexprPart

  /** A `.*` or `(.*)` wildcard. */
  case object Wildcard extends RegexprPart

  /**
   * A run of ASCII digits with repetition bounds `from`..`to`.
   * The parser emits `Digits(5, 5)` for `([0-9]{5})` and `Digits(4, -1)` for
   * `[0-9]{4,}`, i.e. `to == -1` stands for "no upper bound".
   */
  final case class Digits(from: Int, to: Int) extends RegexprPart

  /** One or more characters in the range U+4E00..U+9FA5 (the Chinese character class). */
  case object Chinese extends RegexprPart

  /** A literal string that contains no regex special characters. */
  final case class Fixstring(name: String) extends RegexprPart

  /** Any other regex text that the parser does not break down further. */
  final case class Regexpr(value: String) extends RegexprPart
}

class GpuRLikeMeta(
expr: RLike,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule) extends BinaryExprMeta[RLike](expr, conf, parent, rule) {

import RegexprPart._

// Raw regex text of the RLike literal; assigned in tagExprForGpu and later fed to
// parseRegexToParts by convertToGpu.
private var originalPattern: String = ""
// Transpiled regex string; assigned in tagExprForGpu and stays None until then.
private var pattern: Option[String] = None

// Characters whose presence makes isSimplePattern reject a string as a plain literal.
val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '\\' ,'(', ')')

/**
 * Returns true when `pat` is non-empty and none of its characters is listed in
 * `specialChars`, i.e. it can be treated as a plain literal.
 */
def isSimplePattern(pat: String): Boolean = {
  val hasSpecial = pat.exists(ch => specialChars.contains(ch))
  pat.nonEmpty && !hasSpecial
}

/**
 * Breaks a regex into coarse [[RegexprPart]] tokens by repeatedly peeling a
 * recognised prefix or suffix off `pat`.
 *
 * Cases are tried in order: `^`, `$`, `.*`, `(.*)`, the Chinese character class,
 * the two supported digit-run forms, and one level of enclosing parentheses.
 * Whatever is left becomes a `Fixstring` when it has no special characters and
 * a `Regexpr` otherwise.
 *
 * @param pat the regex text to tokenise; the empty string yields an empty list
 * @return the tokens in left-to-right pattern order
 */
def parseRegexToParts(pat: String): List[RegexprPart] = {
  // Literal forms recognised below. Strip lengths are derived from these literals so
  // the matched text and the removed text cannot get out of sync.
  val chineseClass = "[\u4e00-\u9fa5]+"
  val fiveDigitsGroup = "([0-9]{5})"
  val fourOrMoreDigits = "[0-9]{4,}"

  pat match {
    case "" =>
      Nil
    case s if s.startsWith("^") =>
      Start :: parseRegexToParts(s.substring(1))
    case s if s.endsWith("$") =>
      parseRegexToParts(s.dropRight(1)) :+ End
    case s if s.startsWith(".*") =>
      Wildcard :: parseRegexToParts(s.substring(2))
    case s if s.endsWith(".*") =>
      parseRegexToParts(s.dropRight(2)) :+ Wildcard
    case s if s.startsWith("(.*)") =>
      Wildcard :: parseRegexToParts(s.substring(4))
    case s if s.endsWith("(.*)") =>
      parseRegexToParts(s.dropRight(4)) :+ Wildcard
    case s if s.startsWith(chineseClass) =>
      // The class was matched at the front, so it is removed from the front and its
      // token is prepended. The previous code removed the last six characters and
      // appended the token, dropping the wrong text and reversing the token order.
      Chinese :: parseRegexToParts(s.substring(chineseClass.length))
    case s if s.endsWith(fiveDigitsGroup) =>
      parseRegexToParts(s.dropRight(fiveDigitsGroup.length)) :+ Digits(5, 5)
    case s if s.endsWith(fourOrMoreDigits) =>
      // -1 marks the missing upper bound of `{4,}`.
      parseRegexToParts(s.dropRight(fourOrMoreDigits.length)) :+ Digits(4, -1)
    case s if s.startsWith("(") && s.endsWith(")") =>
      // Unwrap one level of enclosing parentheses and tokenise the inside.
      parseRegexToParts(s.substring(1, s.length - 1))
    case s if isSimplePattern(s) =>
      List(Fixstring(s))
    case s =>
      List(Regexpr(s))
  }
}

/**
 * Picks a cheaper GPU expression for an rlike pattern when its token list has one
 * of the recognised shapes, and falls back to [[GpuRLike]] otherwise.
 *
 * Note the parameter order: the pattern expression (`rhs`) comes first.
 *
 * @param rhs   the pattern expression; only used for the `GpuRLike` fallback
 * @param lhs   the string expression being matched
 * @param parts tokens produced by `parseRegexToParts` for the original pattern
 * @return the rewritten expression, or `GpuRLike` when no rewrite applies
 * @throws IllegalStateException if the fallback is needed but no transpiled
 *                               pattern has been recorded
 */
def optimizeSimplePattern(rhs: Expression, lhs: Expression, parts: List[RegexprPart]):
    GpuExpression = {
  // Code point ranges passed to GpuStringDigits.
  val digitFirst = '0'.toInt    // 48
  val digitLast = '9'.toInt     // 57
  val chineseFirst = 0x4e00     // 19968
  val chineseLast = 0x9fa5      // 40869

  // True for an empty tail as well, so no separate emptiness check is needed.
  def onlyWildcards(rest: List[RegexprPart]): Boolean = rest.forall(_ == Wildcard)

  parts match {
    case Wildcard :: rest =>
      // A leading wildcard does not change a find-style match; drop it and retry.
      optimizeSimplePattern(rhs, lhs, rest)
    // `^.*$` deliberately has no rewrite and reaches the GpuRLike fallback below.
    // It used to become GpuEqualTo(lhs, rhs), which compared the column with the
    // pattern text itself and so returned false for almost every row.
    case Start :: Fixstring(s) :: rest if onlyWildcards(rest) =>
      GpuStartsWith(lhs, GpuLiteral(s, StringType))
    case Fixstring(s) :: List(End) =>
      // NOTE(review): Java's `$` also matches just before a final line terminator,
      // which a plain endsWith does not - confirm this rewrite is acceptable for
      // inputs that may end with a newline.
      GpuEndsWith(lhs, GpuLiteral(s, StringType))
    case Chinese :: rest if onlyWildcards(rest) =>
      GpuStringDigits(lhs, GpuLiteral("", StringType), 1, chineseFirst, chineseLast)
    case Digits(from, _) :: rest if onlyWildcards(rest) =>
      GpuStringDigits(lhs, GpuLiteral("", StringType), from, digitFirst, digitLast)
    case Fixstring(s) :: Digits(from, _) :: rest if onlyWildcards(rest) =>
      GpuStringDigits(lhs, GpuLiteral(s, StringType), from, digitFirst, digitLast)
    case Fixstring(s) :: rest if onlyWildcards(rest) =>
      GpuContains(lhs, GpuLiteral(s, StringType))
    case _ =>
      val patternStr = pattern.getOrElse(throw new IllegalStateException(
        "Expression has not been tagged with cuDF regex pattern"))
      GpuRLike(lhs, rhs, patternStr)
  }
}

/**
 * Legacy rewrite rules: only StartsWith / EndsWith / Contains are produced, and
 * every other shape falls back to [[GpuRLike]].
 *
 * Note the parameter order: the pattern expression (`rhs`) comes first.
 *
 * @param rhs   the pattern expression; only used for the `GpuRLike` fallback
 * @param lhs   the string expression being matched
 * @param parts tokens produced by `parseRegexToParts` for the original pattern
 * @return the rewritten expression, or `GpuRLike` when no rewrite applies
 * @throws IllegalStateException if the fallback is needed but no transpiled
 *                               pattern has been recorded
 */
def optimizeSimplePatternLegancy(rhs: Expression, lhs: Expression, parts: List[RegexprPart]):
    GpuExpression = {
  // True for an empty tail as well, so no separate emptiness check is needed.
  def onlyWildcards(rest: List[RegexprPart]): Boolean = rest.forall(_ == Wildcard)

  parts match {
    case Wildcard :: rest =>
      // Stay within the legacy rules after dropping the leading wildcard. This used
      // to call optimizeSimplePattern, so a pattern such as `.*[0-9]{4,}` picked up
      // the non-legacy rewrites even in legacy mode.
      optimizeSimplePatternLegancy(rhs, lhs, rest)
    // `^.*$` deliberately has no rewrite and reaches the GpuRLike fallback below.
    // It used to become GpuEqualTo(lhs, rhs), which compared the column with the
    // pattern text itself.
    case Start :: Fixstring(s) :: rest if onlyWildcards(rest) =>
      GpuStartsWith(lhs, GpuLiteral(s, StringType))
    case Fixstring(s) :: List(End) =>
      // NOTE(review): Java's `$` also matches just before a final line terminator,
      // which a plain endsWith does not - confirm this rewrite is acceptable.
      GpuEndsWith(lhs, GpuLiteral(s, StringType))
    case Fixstring(s) :: rest if onlyWildcards(rest) =>
      GpuContains(lhs, GpuLiteral(s, StringType))
    case _ =>
      val patternStr = pattern.getOrElse(throw new IllegalStateException(
        "Expression has not been tagged with cuDF regex pattern"))
      GpuRLike(lhs, rhs, patternStr)
  }
}

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.right match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
// verify that we support this regex and can transpile it to cuDF format
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(str.toString, None, None)
originalPattern = str.toString
val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode)
.getTranspiledAST(originalPattern, None, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
pattern = Some(transpiledAST.toRegexString)
} catch {
Expand All @@ -1082,11 +1204,45 @@ class GpuRLikeMeta(
}

/**
 * Builds the GPU expression for this RLike according to the rewrite mode read from
 * `conf.isRlikeRegexRewriteEnabled`:
 *  - "new":    tokenise the original pattern and apply `optimizeSimplePattern`
 *  - "legacy": tokenise the original pattern and apply `optimizeSimplePatternLegancy`
 *  - anything else: always use [[GpuRLike]] with the transpiled pattern
 *
 * @throws IllegalStateException if `GpuRLike` is needed but no transpiled pattern
 *                               has been recorded
 */
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
  conf.isRlikeRegexRewriteEnabled match {
    case "new" =>
      // Simple patterns can be served by GpuStartsWith, GpuEndsWith, GpuContains or
      // GpuStringDigits instead of the regex engine to get better performance.
      optimizeSimplePattern(rhs, lhs, parseRegexToParts(originalPattern))
    case "legacy" =>
      optimizeSimplePatternLegancy(rhs, lhs, parseRegexToParts(originalPattern))
    case _ =>
      val patternStr = pattern.getOrElse(throw new IllegalStateException(
        "Expression has not been tagged with cuDF regex pattern"))
      GpuRLike(lhs, rhs, patternStr)
  }
}
}

/**
 * Boolean string predicate evaluated by `StringDigitsPattern.stringDigitsPattern`.
 *
 * The rlike rewrite builds it with `right` set to a literal prefix (possibly empty),
 * `from` set to the lower repetition bound of the character run, and `start`/`end`
 * set to the code points delimiting the character range.
 * NOTE(review): the exact matching semantics live in the JNI kernel and are not
 * visible here - confirm against its documentation.
 */
case class GpuStringDigits(left: Expression, right: Expression, from: Int, start: Int, end: Int)
  extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes with NullIntolerant {

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

  override def dataType: DataType = BooleanType

  /** Column-vs-scalar case: hands both base values to the JNI kernel. */
  override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
    val strings = lhs.getBase
    val prefix = rhs.getBase
    StringDigitsPattern.stringDigitsPattern(strings, prefix, from, start, end)
  }

  /** Scalar-vs-scalar case: expands the left scalar to `numRows` rows, then reuses the column path. */
  override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
    withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { lhsColumn =>
      doColumnar(lhsColumn, rhs)
    }
  }
}

case class GpuRLike(left: Expression, right: Expression, pattern: String)
extends GpuBinaryExpressionArgsAnyScalar
with ImplicitCastInputTypes
Expand Down

0 comments on commit cfc27b5

Please sign in to comment.