Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transpile simple choice-type regular expressions into lists of choices to use with string replace multi #7967

Merged
merged 10 commits into from
Apr 3, 2023
59 changes: 56 additions & 3 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
else:
pytestmark = pytest.mark.regexp

_regexp_conf = { 'spark.rapids.sql.regexp.enabled': 'true' }
_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True }

def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
Expand Down Expand Up @@ -804,6 +804,59 @@ def test_regexp_replace_unicode_support():
),
conf=_regexp_conf)

@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_regexp_replace_fallback():
gen = mk_str_gen('[abcdef]{0,2}')

conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: here and in other places, the conf values could use typed constants (e.g. the boolean `False` instead of the string `'false'`).

Suggested change
conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
conf = { 'spark.rapids.sql.regexp.enabled': False }


assert_gpu_fallback_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'REGEXP_REPLACE(a, "[a-z]+", "PROD")',
'REGEXP_REPLACE(a, "aa", "PROD")',
),
cpu_fallback_class_name='RegExpReplace',
conf=conf
)

@pytest.mark.parametrize("regexp_enabled", ['true', 'false'])
def test_regexp_replace_simple(regexp_enabled):
    """Run REGEXP_REPLACE with plain two-character literal patterns.

    The query is passed to assert_gpu_and_cpu_are_equal_collect once for each
    parametrized value of 'spark.rapids.sql.regexp.enabled'.
    """
    str_gen = mk_str_gen('[abcdef]{0,2}')

    # None of these patterns contain regex metacharacters.
    replace_exprs = [
        'REGEXP_REPLACE(a, "aa", "PROD")',
        'REGEXP_REPLACE(a, "ab", "PROD")',
        'REGEXP_REPLACE(a, "ae", "PROD")',
        'REGEXP_REPLACE(a, "bc", "PROD")',
        'REGEXP_REPLACE(a, "fa", "PROD")',
    ]

    def run_query(spark):
        return unary_op_df(spark, str_gen).selectExpr(*replace_exprs)

    assert_gpu_and_cpu_are_equal_collect(
        run_query,
        conf={'spark.rapids.sql.regexp.enabled': regexp_enabled})

@pytest.mark.parametrize("regexp_enabled", ['true', 'false'])
def test_regexp_replace_multi_optimization(regexp_enabled):
    """Run REGEXP_REPLACE with alternation patterns of two to six choices.

    Both bare ("aa|bb") and parenthesized ("(aa)|(bb)", "(aa|bb)|(cc|dd)")
    forms are included. The query is passed to
    assert_gpu_and_cpu_are_equal_collect once for each parametrized value of
    'spark.rapids.sql.regexp.enabled'.
    """
    str_gen = mk_str_gen('[abcdef]{0,2}')

    choice_exprs = [
        'REGEXP_REPLACE(a, "aa|bb", "PROD")',
        'REGEXP_REPLACE(a, "(aa)|(bb)", "PROD")',
        'REGEXP_REPLACE(a, "aa|bb|cc", "PROD")',
        'REGEXP_REPLACE(a, "(aa)|(bb)|(cc)", "PROD")',
        'REGEXP_REPLACE(a, "aa|bb|cc|dd", "PROD")',
        'REGEXP_REPLACE(a, "(aa|bb)|(cc|dd)", "PROD")',
        'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee", "PROD")',
        'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")',
    ]

    def run_query(spark):
        return unary_op_df(spark, str_gen).selectExpr(*choice_exprs)

    assert_gpu_and_cpu_are_equal_collect(
        run_query,
        conf={'spark.rapids.sql.regexp.enabled': regexp_enabled})


def test_regexp_split_unicode_support():
data_gen = mk_str_gen('([bf]o{0,2}青){1,7}') \
.with_special_case('boo青and青foo')
Expand Down Expand Up @@ -836,7 +889,7 @@ def test_regexp_memory_fallback():
),
cpu_fallback_class_name='RLike',
conf={
'spark.rapids.sql.regexp.enabled': 'true',
'spark.rapids.sql.regexp.enabled': True,
'spark.rapids.sql.regexp.maxStateMemoryBytes': '10',
'spark.rapids.sql.batchSizeBytes': '20' # 1 row in the batch
}
Expand All @@ -858,7 +911,7 @@ def test_regexp_memory_ok():
'a rlike "1|2|3|4|5|6"'
),
conf={
'spark.rapids.sql.regexp.enabled': 'true',
'spark.rapids.sql.regexp.enabled': True,
'spark.rapids.sql.regexp.maxStateMemoryBytes': '12',
'spark.rapids.sql.batchSizeBytes': '20' # 1 row in the batch
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ object GpuOverrides extends Logging {
lit.value == null
}

def isSupportedStringReplacePattern(strLit: String): Boolean = {
// check for regex special characters, except for \u0000 which we can support
!regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's more readable in a conjunctive form, but it is not part of your PR, so very optional:

Suggested change
!regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern))
!regexList.exists(pattern => pattern != "\u0000" && strLit.contains(pattern))

}

def isSupportedStringReplacePattern(exp: Expression): Boolean = {
extractLit(exp) match {
case Some(Literal(null, _)) => false
Expand All @@ -602,7 +607,7 @@ object GpuOverrides extends Logging {
false
} else {
// check for regex special characters, except for \u0000 which we can support
!regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern))
isSupportedStringReplacePattern(strLit)
}
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@
package com.nvidia.spark.rapids

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

/**
 * Marker for an optimization that [[GpuRegExpReplaceMeta]] may select while tagging a
 * `RegExpReplace` expression. The selected marker, if any, is passed through to
 * `GpuRegExpReplace` when the expression is converted.
 */
trait GpuRegExpReplaceOpt extends Serializable

/**
 * Selected when the replacement contains no back-reference and
 * `GpuOverrides.isSupportedStringReplacePattern` accepts the pattern expression.
 */
@SerialVersionUID(100L)
object GpuRegExpStringReplace extends GpuRegExpReplaceOpt

/**
 * Selected when `GpuRegExpUtils.getChoicesFromRegex` yields a list of choices for the
 * transpiled pattern and the replacement contains no back-reference.
 */
@SerialVersionUID(100L)
object GpuRegExpStringReplaceMulti extends GpuRegExpReplaceOpt

class GpuRegExpReplaceMeta(
expr: RegExpReplace,
conf: RapidsConf,
Expand All @@ -30,37 +38,44 @@ class GpuRegExpReplaceMeta(
private var javaPattern: Option[String] = None
private var cudfPattern: Option[String] = None
private var replacement: Option[String] = None
private var canUseGpuStringReplace = false
private var searchList: Option[Seq[String]] = None
private var replaceOpt: Option[GpuRegExpReplaceOpt] = None
private var containsBackref: Boolean = false

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
replacement = expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null => Some(s.toString)
case _ => None
}

expr.regexp match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
canUseGpuStringReplace = true
} else {
try {
javaPattern = Some(s.toString())
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
replacement)
GpuRegExpUtils.validateRegExpComplexity(this, pat)
cudfPattern = Some(pat.toRegexString)
repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
}
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
javaPattern = Some(s.toString())
try {
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
replacement)
repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
}
if (!containsBackref && GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
replaceOpt = Some(GpuRegExpStringReplace)
} else {
searchList = GpuRegExpUtils.getChoicesFromRegex(pat)
searchList match {
case Some(_) if !containsBackref =>
replaceOpt = Some(GpuRegExpStringReplaceMulti)
case _ =>
GpuRegExpUtils.tagForRegExpEnabled(this)
GpuRegExpUtils.validateRegExpComplexity(this, pat)
cudfPattern = Some(pat.toRegexString)
}
}
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
}

case _ =>
Expand All @@ -82,19 +97,27 @@ class GpuRegExpReplaceMeta(
// ignore the pos expression which must be a literal 1 after tagging check
require(childExprs.length == 4,
s"Unexpected child count for RegExpReplace: ${childExprs.length}")
if (canUseGpuStringReplace) {
GpuStringReplace(lhs, regexp, rep)
} else {
(javaPattern, cudfPattern, replacement) match {
case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) =>
if (containsBackref) {
GpuRegExpReplaceWithBackref(lhs, regexp, rep)(cudfPattern, cudfReplacement)
} else {
GpuRegExpReplace(lhs, regexp, rep)(javaPattern, cudfPattern, cudfReplacement)
}
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
replaceOpt match {
case None =>
(javaPattern, cudfPattern, replacement) match {
case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) =>
if (containsBackref) {
GpuRegExpReplaceWithBackref(lhs, regexp, rep)(cudfPattern, cudfReplacement)
} else {
GpuRegExpReplace(lhs, regexp, rep)(javaPattern, cudfPattern, cudfReplacement,
None, None)
}
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
case _ =>
(javaPattern, replacement) match {
case (Some(javaPattern), Some(replacement)) =>
GpuRegExpReplace(lhs, regexp, rep)(javaPattern, javaPattern, replacement,
searchList, replaceOpt)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Loading