Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transpile simple choice-type regular expressions into lists of choices to use with string replace multi #7967

Merged
merged 10 commits into from
Apr 3, 2023
52 changes: 52 additions & 0 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,58 @@ def test_regexp_replace_unicode_support():
),
conf=_regexp_conf)

@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_regexp_replace_fallback():
gen = mk_str_gen('[abcdef]{0,2}')

conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, here and other places, could use typed constants

Suggested change
conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
conf = { 'spark.rapids.sql.regexp.enabled': False }


assert_gpu_fallback_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'REGEXP_REPLACE(a, "[a-z]+", "PROD")',
'REGEXP_REPLACE(a, "aa", "PROD")',
),
cpu_fallback_class_name='RegExpReplace',
conf=conf
)

@pytest.mark.parametrize("regexp_enabled", ['true', 'false'])
def test_regexp_replace_simple(regexp_enabled):
gen = mk_str_gen('[abcdef]{0,2}')

conf = { 'spark.rapids.sql.regexp.enabled': regexp_enabled }

assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'REGEXP_REPLACE(a, "aa", "PROD")',
'REGEXP_REPLACE(a, "ab", "PROD")',
'REGEXP_REPLACE(a, "ae", "PROD")',
'REGEXP_REPLACE(a, "bc", "PROD")',
'REGEXP_REPLACE(a, "fa", "PROD")'
),
conf=conf
)

@pytest.mark.parametrize("regexp_enabled", ['true', 'false'])
def test_regexp_replace_multi_optimization(regexp_enabled):
gen = mk_str_gen('[abcdef]{0,2}')

conf = { 'spark.rapids.sql.regexp.enabled': regexp_enabled }

assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'REGEXP_REPLACE(a, "aa|bb", "PROD")',
'REGEXP_REPLACE(a, "(aa)|(bb)", "PROD")',
'REGEXP_REPLACE(a, "aa|bb|cc", "PROD")',
'REGEXP_REPLACE(a, "(aa)|(bb)|(cc)", "PROD")',
NVnavkumar marked this conversation as resolved.
Show resolved Hide resolved
'REGEXP_REPLACE(a, "aa|bb|cc|dd", "PROD")',
'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee", "PROD")',
'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")'
),
conf=conf
)


def test_regexp_split_unicode_support():
data_gen = mk_str_gen('([bf]o{0,2}青){1,7}') \
.with_special_case('boo青and青foo')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ object GpuOverrides extends Logging {
lit.value == null
}

def isSupportedStringReplacePattern(strLit: String): Boolean = {
// check for regex special characters, except for \u0000 which we can support
!regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's more readable in a conjunctive form, but it is not part of your PR, so very optional:

Suggested change
!regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern))
!regexList.exists(pattern => pattern != "\u0000" && strLit.contains(pattern))

}

def isSupportedStringReplacePattern(exp: Expression): Boolean = {
extractLit(exp) match {
case Some(Literal(null, _)) => false
Expand All @@ -602,7 +607,7 @@ object GpuOverrides extends Logging {
false
} else {
// check for regex special characters, except for \u0000 which we can support
!regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern))
isSupportedStringReplacePattern(strLit)
}
case _ => false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@
package com.nvidia.spark.rapids

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

trait GpuRegExpReplaceOpt extends Serializable

@SerialVersionUID(100L)
object GpuRegExpStringReplace extends GpuRegExpReplaceOpt

@SerialVersionUID(100L)
object GpuRegExpStringReplaceMulti extends GpuRegExpReplaceOpt

class GpuRegExpReplaceMeta(
expr: RegExpReplace,
conf: RapidsConf,
Expand All @@ -30,37 +38,44 @@ class GpuRegExpReplaceMeta(
private var javaPattern: Option[String] = None
private var cudfPattern: Option[String] = None
private var replacement: Option[String] = None
private var canUseGpuStringReplace = false
private var searchList: Option[Seq[String]] = None
private var replaceOpt: Option[GpuRegExpReplaceOpt] = None
private var containsBackref: Boolean = false

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
replacement = expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null => Some(s.toString)
case _ => None
}

expr.regexp match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
canUseGpuStringReplace = true
} else {
try {
javaPattern = Some(s.toString())
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
replacement)
GpuRegExpUtils.validateRegExpComplexity(this, pat)
cudfPattern = Some(pat.toRegexString)
repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
}
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
javaPattern = Some(s.toString())
try {
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
replacement)
repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
}
if (!containsBackref && GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
replaceOpt = Some(GpuRegExpStringReplace)
} else {
searchList = GpuRegExpUtils.getChoicesFromRegex(pat)
searchList match {
case Some(_) if !containsBackref =>
replaceOpt = Some(GpuRegExpStringReplaceMulti)
case _ =>
GpuRegExpUtils.tagForRegExpEnabled(this)
GpuRegExpUtils.validateRegExpComplexity(this, pat)
cudfPattern = Some(pat.toRegexString)
}
}
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
}

case _ =>
Expand All @@ -82,19 +97,27 @@ class GpuRegExpReplaceMeta(
// ignore the pos expression which must be a literal 1 after tagging check
require(childExprs.length == 4,
s"Unexpected child count for RegExpReplace: ${childExprs.length}")
if (canUseGpuStringReplace) {
GpuStringReplace(lhs, regexp, rep)
} else {
(javaPattern, cudfPattern, replacement) match {
case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) =>
if (containsBackref) {
GpuRegExpReplaceWithBackref(lhs, cudfPattern, cudfReplacement)
} else {
GpuRegExpReplace(lhs, regexp, rep, javaPattern, cudfPattern, cudfReplacement)
}
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
replaceOpt match {
case None =>
(javaPattern, cudfPattern, replacement) match {
case (Some(javaPattern), Some(cudfPattern), Some(cudfReplacement)) =>
if (containsBackref) {
GpuRegExpReplaceWithBackref(lhs, cudfPattern, cudfReplacement)
} else {
GpuRegExpReplace(lhs, regexp, rep, javaPattern, cudfPattern, cudfReplacement,
None, None)
}
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
case _ =>
(javaPattern, replacement) match {
case (Some(javaPattern), Some(replacement)) =>
GpuRegExpReplace(lhs, regexp, rep, javaPattern, javaPattern, replacement,
searchList, replaceOpt)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -742,11 +742,39 @@ case class GpuStringRepeat(input: Expression, repeatTimes: Expression)

}

trait HasGpuStringReplace extends Arm {
def doStringReplace(
strExpr: GpuColumnVector,
searchExpr: GpuScalar,
replaceExpr: GpuScalar): ColumnVector = {
// When search or replace string is null, return all nulls like the CPU does.
if (!searchExpr.isValid || !replaceExpr.isValid) {
GpuColumnVector.columnVectorFromNull(strExpr.getRowCount.toInt, StringType)
} else if (searchExpr.getValue.asInstanceOf[UTF8String].numChars() == 0) {
// Return original string if search string is empty
strExpr.getBase.asStrings()
} else {
strExpr.getBase.stringReplace(searchExpr.getBase, replaceExpr.getBase)
}
}

def doStringReplaceMulti(
strExpr: GpuColumnVector,
search: Seq[String],
replacement: String): ColumnVector = {
withResource(ColumnVector.fromStrings(search: _*)) { targets =>
withResource(ColumnVector.fromStrings(replacement)) { repls =>
strExpr.getBase.stringReplace(targets, repls)
}
}
}
}

case class GpuStringReplace(
srcExpr: Expression,
searchExpr: Expression,
replaceExpr: Expression)
extends GpuTernaryExpression with ImplicitCastInputTypes {
extends GpuTernaryExpression with ImplicitCastInputTypes with HasGpuStringReplace {

override def dataType: DataType = srcExpr.dataType

Expand Down Expand Up @@ -794,15 +822,7 @@ case class GpuStringReplace(
strExpr: GpuColumnVector,
searchExpr: GpuScalar,
replaceExpr: GpuScalar): ColumnVector = {
// When search or replace string is null, return all nulls like the CPU does.
if (!searchExpr.isValid || !replaceExpr.isValid) {
GpuColumnVector.columnVectorFromNull(strExpr.getRowCount.toInt, StringType)
} else if (searchExpr.getValue.asInstanceOf[UTF8String].numChars() == 0) {
// Return original string if search string is empty
strExpr.getBase.asStrings()
} else {
strExpr.getBase.stringReplace(searchExpr.getBase, replaceExpr.getBase)
}
doStringReplace(strExpr, searchExpr, replaceExpr)
}

override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar,
Expand Down Expand Up @@ -996,6 +1016,40 @@ object GpuRegExpUtils {
countGroups(parseAST(pattern))
}

def getChoicesFromRegex(regex: RegexAST): Option[Seq[String]] = {
NVnavkumar marked this conversation as resolved.
Show resolved Hide resolved
regex match {
case RegexGroup(_, t, None) =>
if (GpuOverrides.isSupportedStringReplacePattern(t.toRegexString)) {
Some(Seq(t.toRegexString))
} else {
None
}
case RegexChoice(a, b) =>
getChoicesFromRegex(a) match {
case Some(la) =>
getChoicesFromRegex(b) match {
case Some(lb) => Some(la ++ lb)
case _ => None
}
case _ => None
}
case RegexSequence(parts) =>
if (GpuOverrides.isSupportedStringReplacePattern(regex.toRegexString)) {
Some(Seq(regex.toRegexString))
} else {
parts.foldLeft(Some(Seq[String]()): Option[Seq[String]]) { (m: Option[Seq[String]], r) =>
getChoicesFromRegex(r) match {
case Some(l) => m.map(_ ++ l)
case _ => None
}
}
}
case _ =>
None
}
}


}

class GpuRLikeMeta(
Expand Down Expand Up @@ -1114,8 +1168,10 @@ case class GpuRegExpReplace(
replaceExpr: Expression,
javaRegexpPattern: String,
cudfRegexPattern: String,
cudfReplacementString: String)
extends GpuRegExpTernaryBase with ImplicitCastInputTypes {
cudfReplacementString: String,
searchList: Option[Seq[String]],
replaceOpt: Option[GpuRegExpReplaceOpt])
extends GpuRegExpTernaryBase with ImplicitCastInputTypes with HasGpuStringReplace {

override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)

Expand All @@ -1127,7 +1183,7 @@ case class GpuRegExpReplace(
cudfRegexPattern: String, cudfReplacementString: String) = {

this(srcExpr, searchExpr, GpuLiteral("", StringType), javaRegexpPattern,
cudfRegexPattern, cudfReplacementString)
cudfRegexPattern, cudfReplacementString, None, None)
}

override def doColumnar(
Expand All @@ -1137,28 +1193,41 @@ case class GpuRegExpReplace(
// For empty strings and a regex containing only a zero-match repetition,
// the behavior in some versions of Spark is different.
// see https://github.com/NVIDIA/spark-rapids/issues/5456
val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE)
if (SparkShimImpl.reproduceEmptyStringBug &&
GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) {
val isEmpty = withResource(strExpr.getBase.getCharLengths) { len =>
withResource(Scalar.fromInt(0)) { zero =>
len.equalTo(zero)
replaceOpt match {
case Some(GpuRegExpStringReplace) =>
doStringReplace(strExpr, searchExpr, replaceExpr)
case Some(GpuRegExpStringReplaceMulti) =>
searchList match {
case Some(searches) =>
doStringReplaceMulti(strExpr, searches, cudfReplacementString)
case _ =>
throw new IllegalStateException("Need a replace")
}
}
withResource(isEmpty) { _ =>
withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString =>
withResource(GpuScalar.from(cudfReplacementString, DataTypes.StringType)) { rep =>
withResource(strExpr.getBase.replaceRegex(prog, rep)) { replacement =>
isEmpty.ifElse(emptyString, replacement)
case _ =>
val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE)
if (SparkShimImpl.reproduceEmptyStringBug &&
GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) {
val isEmpty = withResource(strExpr.getBase.getCharLengths) { len =>
withResource(Scalar.fromInt(0)) { zero =>
len.equalTo(zero)
}
}
withResource(isEmpty) { _ =>
withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString =>
withResource(GpuScalar.from(cudfReplacementString, DataTypes.StringType)) { rep =>
withResource(strExpr.getBase.replaceRegex(prog, rep)) { replacement =>
isEmpty.ifElse(emptyString, replacement)
}
}
}
}
} else {
withResource(Scalar.fromString(cudfReplacementString)) { rep =>
strExpr.getBase.replaceRegex(prog, rep)
}
}
}
} else {
withResource(Scalar.fromString(cudfReplacementString)) { rep =>
strExpr.getBase.replaceRegex(prog, rep)
}
}

}

}
Expand Down
Loading