Add regular expression support to string_split (#4714)
Signed-off-by: Andy Grove <[email protected]>
andygrove authored Feb 14, 2022
1 parent 3f476b4 commit 3c48c96
Showing 9 changed files with 275 additions and 75 deletions.
2 changes: 2 additions & 0 deletions docs/compatibility.md
@@ -522,6 +522,7 @@ The following Apache Spark regular expression functions and expressions are supp
- `regexp_extract`
- `regexp_like`
- `regexp_replace`
- `string_split`

Regular expression evaluation on the GPU can potentially have high memory overhead and cause out-of-memory errors. To
disable regular expressions on the GPU, set `spark.rapids.sql.regexp.enabled=false`.
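For reference, the setting above can be applied to a live session; a minimal sketch, assuming an existing `SparkSession` named `spark` with the RAPIDS plugin on the classpath:

```scala
// Sketch only: forces all regular expression functions back to the CPU.
spark.conf.set("spark.rapids.sql.regexp.enabled", "false")

// The same key can be supplied at submit time instead:
//   spark-submit --conf spark.rapids.sql.regexp.enabled=false ...
```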
@@ -535,6 +536,7 @@ Here are some examples of regular expression patterns that are not supported on
- Line anchor `$`
- String anchor `\Z`
- String anchor `\z` is not supported by `regexp_replace`
- Line and string anchors are not supported by `string_split`
- Non-digit character class `\D`
- Non-word character class `\W`
- Word and non-word boundaries, `\b` and `\B`
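To make the list concrete, here is a small sketch of two `split` calls, assuming a `SparkSession` named `spark`: the first pattern uses a line anchor, which the list above marks as unsupported for `string_split` (so it would be expected to fall back to the CPU when the plugin is enabled), while the second sticks to supported constructs.

```scala
// Illustrative only: both queries are ordinary Spark SQL and also run fine on the CPU.
spark.sql("SELECT split(value, '^a') FROM VALUES ('abc'), ('aaa') AS t(value)").show(false)   // line anchor: expected CPU fallback
spark.sql("SELECT split(value, '[ab]') FROM VALUES ('abc'), ('aaa') AS t(value)").show(false) // supported character class
```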
108 changes: 105 additions & 3 deletions integration_tests/src/main/python/string_test.py
@@ -14,7 +14,9 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, \
assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error, \
assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime
from data_gen import *
from marks import *
@@ -25,15 +27,115 @@
def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

def test_split():
def test_split_no_limit():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
delim = '_'
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB")',
'split(a, "C")',
'split(a, "_")'))

def test_split_negative_limit():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB", -1)',
'split(a, "C", -2)',
'split(a, "_", -999)'))

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_zero_limit_fallback():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB", 0)'),
exist_classes= "ProjectExec",
non_exist_classes= "GpuProjectExec")

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_one_limit_fallback():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB", 1)'),
exist_classes= "ProjectExec",
non_exist_classes= "GpuProjectExec")

def test_split_positive_limit():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB", 2)',
'split(a, "C", 3)',
'split(a, "_", 999)'))

def test_split_re_negative_limit():
data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
.with_special_case('boo:and:foo')
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "[:]", -1)',
'split(a, "[o:]", -1)',
'split(a, "[^:]", -1)',
'split(a, "[^o]", -1)',
'split(a, "[o]{1,2}", -1)',
'split(a, "[bf]", -1)',
'split(a, "[o]", -2)'))

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_re_zero_limit_fallback():
data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
.with_special_case('boo:and:foo')
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "[:]", 0)',
'split(a, "[o:]", 0)',
'split(a, "[o]", 0)'),
exist_classes= "ProjectExec",
non_exist_classes= "GpuProjectExec")

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_re_one_limit_fallback():
data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
.with_special_case('boo:and:foo')
assert_cpu_and_gpu_are_equal_collect_with_capture(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "[:]", 1)',
'split(a, "[o:]", 1)',
'split(a, "[o]", 1)'),
exist_classes= "ProjectExec",
non_exist_classes= "GpuProjectExec")

def test_split_re_positive_limit():
data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
.with_special_case('boo:and:foo')
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "[:]", 2)',
'split(a, "[o:]", 5)',
'split(a, "[^:]", 2)',
'split(a, "[^o]", 55)',
'split(a, "[o]{1,2}", 999)',
'split(a, "[bf]", 2)',
'split(a, "[o]", 5)'))

def test_split_re_no_limit():
data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
.with_special_case('boo:and:foo')
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "[:]")',
'split(a, "[o:]")',
'split(a, "[^:]")',
'split(a, "[^o]")',
'split(a, "[o]{1,2}")',
'split(a, "[bf]")',
'split(a, "[o]")'))

@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
(mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
(mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
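For readers skimming the new tests, a brief sketch of the plain Spark `split(str, regex, limit)` semantics they exercise; nothing below is plugin-specific, and the input string is taken from the tests above.

```scala
// Assumes a SparkSession named `spark`.
// limit < 0: apply the regex as many times as possible.
spark.sql("SELECT split('boo:and:foo', '[:]', -1)").show(false) // [boo, and, foo]

// limit > 1: at most `limit` elements; the remainder stays in the last element.
spark.sql("SELECT split('boo:and:foo', '[:]', 2)").show(false)  // [boo, and:foo]

// limit = 1: the whole input comes back as a single element. Limits of 0 and 1 are the
// cases the fallback tests above pin to the CPU
// (https://github.com/NVIDIA/spark-rapids/issues/4720).
spark.sql("SELECT split('boo:and:foo', '[:]', 1)").show(false)  // [boo:and:foo]
```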
Next changed file:
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta}
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Next changed file:
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Next changed file:
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Next changed file:
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
// use GpuStringReplace
} else {
try {
pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
Next changed file:
@@ -433,14 +433,19 @@ object RegexParser {
}
}

sealed trait RegexMode
object RegexFindMode extends RegexMode
object RegexReplaceMode extends RegexMode
object RegexSplitMode extends RegexMode

/**
* Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception
* if this is not possible.
*
* @param mode the intended use of the transpiled pattern: RegexFindMode (matching, e.g.
*             rlike or regexp_extract), RegexReplaceMode (regexp_replace), or RegexSplitMode (string_split)
*/
class CudfRegexTranspiler(replace: Boolean) {
class CudfRegexTranspiler(mode: RegexMode) {

// cuDF throws a "nothing to repeat" exception for many of the edge cases that are
// rejected by the transpiler
@@ -472,6 +477,8 @@ class CudfRegexTranspiler(replace: Boolean) {
case '$' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4533
throw new RegexUnsupportedException("line anchor $ is not supported")
case '^' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("line anchor ^ is not supported in split mode")
case _ =>
regex
}
@@ -506,8 +513,14 @@
case 's' | 'S' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4528
throw new RegexUnsupportedException("whitespace classes are not supported")
case 'A' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
case 'Z' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\Z is not supported in split mode")
case 'z' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\z is not supported in split mode")
case 'z' =>
if (replace) {
if (mode == RegexReplaceMode) {
// see https://github.com/NVIDIA/spark-rapids/issues/4425
throw new RegexUnsupportedException(
"string anchor \\z is not supported in replace mode")
@@ -607,7 +620,7 @@ class CudfRegexTranspiler(replace: Boolean) {
RegexSequence(parts.map(rewrite))

case RegexRepetition(base, quantifier) => (base, quantifier) match {
case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) =>
case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) =>
// example: pattern " ?", input "] b[", replace with "X":
// java: X]XXbX[X
// cuDF: XXXX] b[
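A short sketch of how the new `RegexMode` parameter is meant to be used; the constructor and `transpile` calls mirror the call sites changed in this commit, while the pattern and value names are illustrative.

```scala
// Each caller now states what the transpiled pattern will be used for.
val findPattern    = new CudfRegexTranspiler(RegexFindMode).transpile("[o]{1,2}")    // rlike / regexp_extract
val replacePattern = new CudfRegexTranspiler(RegexReplaceMode).transpile("[o]{1,2}") // regexp_replace
val splitPattern   = new CudfRegexTranspiler(RegexSplitMode).transpile("[o]{1,2}")   // string_split

// Split mode additionally rejects line and string anchors, so this would be expected to
// throw RegexUnsupportedException("line anchor ^ is not supported in split mode"):
// new CudfRegexTranspiler(RegexSplitMode).transpile("^foo")
```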
Next changed file:
@@ -833,7 +833,7 @@ class GpuRLikeMeta(
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(replace = false).transpile(str.toString))
pattern = Some(new CudfRegexTranspiler(RegexFindMode).transpile(str.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
@@ -981,7 +981,7 @@ class GpuRegExpExtractMeta(
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
val cudfRegexPattern = new CudfRegexTranspiler(replace = false)
val cudfRegexPattern = new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern)
pattern = Some(cudfRegexPattern)
numGroups = countGroups(new RegexParser(javaRegexpPattern).parse())
@@ -1324,51 +1324,69 @@ class GpuStringSplitMeta(
extends TernaryExprMeta[StringSplit](expr, conf, parent, rule) {
import GpuOverrides._

private var pattern: Option[String] = None
private var isRegExp = false

override def tagExprForGpu(): Unit = {
val regexp = extractLit(expr.regex)
if (regexp.isEmpty) {
willNotWorkOnGpu("only literal regexp values are supported")
} else {
val str = regexp.get.value.asInstanceOf[UTF8String]
if (str != null) {
if (RegexParser.isRegExpString(str.toString)) {
willNotWorkOnGpu("regular expressions are not supported yet")
}
if (str.numChars() == 0) {
willNotWorkOnGpu("An empty regex is not supported yet")
}
isRegExp = RegexParser.isRegExpString(str.toString)
if (isRegExp) {
try {
pattern = Some(new CudfRegexTranspiler(RegexSplitMode).transpile(str.toString))
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
}
} else {
pattern = Some(str.toString)
}
} else {
willNotWorkOnGpu("null regex is not supported yet")
}
}
if (!isLit(expr.limit)) {
willNotWorkOnGpu("only literal limit is supported")
extractLit(expr.limit) match {
case Some(Literal(n: Int, _)) =>
if (n == 0 || n == 1) {
// https://github.com/NVIDIA/spark-rapids/issues/4720
willNotWorkOnGpu("limit of 0 or 1 is not supported")
}
case _ =>
willNotWorkOnGpu("only literal limit is supported")
}
}
override def convertToGpu(
str: Expression,
regexp: Expression,
limit: Expression): GpuExpression =
GpuStringSplit(str, regexp, limit)
limit: Expression): GpuExpression = {
GpuStringSplit(str, regexp, limit, isRegExp, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
}
}

case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression)
extends GpuTernaryExpression with ImplicitCastInputTypes {
case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
isRegExp: Boolean, pattern: String)
extends GpuTernaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def first: Expression = str
override def second: Expression = regex
override def third: Expression = limit

def this(exp: Expression, regex: Expression) = this(exp, regex, GpuLiteral(-1, IntegerType))

override def prettyName: String = "split"

override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
limit: GpuScalar): ColumnVector = {
val intLimit = limit.getValue.asInstanceOf[Int]
str.getBase.stringSplitRecord(regex.getBase, intLimit)
str.getBase.stringSplitRecord(pattern, intLimit, isRegExp)
}

override def doColumnar(numRows: Int, val0: GpuScalar, val1: GpuScalar,
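Finally, a hedged end-to-end illustration of the two paths `GpuStringSplitMeta` distinguishes above: a plain delimiter that `RegexParser.isRegExpString` does not treat as a regular expression, and a pattern that it does, which is then transpiled in `RegexSplitMode` for cuDF. Both queries are ordinary Spark SQL, assuming a `SparkSession` named `spark`; the inputs are illustrative.

```scala
// Plain-string path: "_" contains no regex metacharacters, so isRegExp stays false.
spark.sql("SELECT split('AB_CD_EF', '_', -1)").show(false)

// Regex path: "[o:]" is detected as a regular expression and transpiled for cuDF.
spark.sql("SELECT split('boo:and:foo', '[o:]', -1)").show(false)
```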