Add regular expression support to string_split #4714

Merged 16 commits on Feb 14, 2022
docs/compatibility.md (+2, -0)
@@ -522,6 +522,7 @@ The following Apache Spark regular expression functions and expressions are supported on the GPU:
- `regexp_extract`
- `regexp_like`
- `regexp_replace`
- `string_split`

Regular expression evaluation on the GPU can potentially have high memory overhead and cause out-of-memory errors. To
disable regular expressions on the GPU, set `spark.rapids.sql.regexp.enabled=false`.
@@ -535,6 +536,7 @@ Here are some examples of regular expression patterns that are not supported on the GPU:
- Line anchor `$`
- String anchor `\Z`
- String anchor `\z` is not supported by `regexp_replace`
- Line and string anchors are not supported by `string_split`
- Non-digit character class `\D`
- Non-word character class `\W`
- Word and non-word boundaries, `\b` and `\B`
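The compatibility note above documents `spark.rapids.sql.regexp.enabled` as the switch for GPU regular expression evaluation. A minimal sketch of setting it at runtime (the session setup is assumed; the key can equally be passed with `--conf` at launch):

```scala
import org.apache.spark.sql.SparkSession

// Assumes a session running with the RAPIDS plugin; the config key is the one
// documented above. Setting it to "false" routes all regular expression
// evaluation, including string_split with a regex delimiter, back to the CPU.
val spark = SparkSession.builder().getOrCreate()
spark.conf.set("spark.rapids.sql.regexp.enabled", "false")
```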
integration_tests/src/main/python/string_test.py (+105, -3)
@@ -14,7 +14,9 @@

import pytest

-from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, \
+    assert_gpu_sql_fallback_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error, \
+    assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime
from data_gen import *
from marks import *
@@ -25,15 +27,115 @@
def mk_str_gen(pattern):
    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

-def test_split():
+def test_split_no_limit():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    delim = '_'
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB")',
            'split(a, "C")',
            'split(a, "_")'))

def test_split_negative_limit():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB", -1)',
            'split(a, "C", -2)',
            'split(a, "_", -999)'))

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_zero_limit_fallback():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB", 0)'),
        exist_classes="ProjectExec",
        non_exist_classes="GpuProjectExec")

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_one_limit_fallback():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB", 1)'),
        exist_classes="ProjectExec",
        non_exist_classes="GpuProjectExec")

def test_split_positive_limit():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB", 2)',
            'split(a, "C", 3)',
            'split(a, "_", 999)'))

def test_split_re_negative_limit():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
        .with_special_case('boo:and:foo')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "[:]", -1)',
            'split(a, "[o:]", -1)',
            'split(a, "[^:]", -1)',
            'split(a, "[^o]", -1)',
            'split(a, "[o]{1,2}", -1)',
            'split(a, "[bf]", -1)',
            'split(a, "[o]", -2)'))

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_re_zero_limit_fallback():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
        .with_special_case('boo:and:foo')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "[:]", 0)',
            'split(a, "[o:]", 0)',
            'split(a, "[o]", 0)'),
        exist_classes="ProjectExec",
        non_exist_classes="GpuProjectExec")

# https://github.com/NVIDIA/spark-rapids/issues/4720
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_re_one_limit_fallback():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
        .with_special_case('boo:and:foo')
    assert_cpu_and_gpu_are_equal_collect_with_capture(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "[:]", 1)',
            'split(a, "[o:]", 1)',
            'split(a, "[o]", 1)'),
        exist_classes="ProjectExec",
        non_exist_classes="GpuProjectExec")

def test_split_re_positive_limit():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
        .with_special_case('boo:and:foo')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "[:]", 2)',
            'split(a, "[o:]", 5)',
            'split(a, "[^:]", 2)',
            'split(a, "[^o]", 55)',
            'split(a, "[o]{1,2}", 999)',
            'split(a, "[bf]", 2)',
            'split(a, "[o]", 5)'))

def test_split_re_no_limit():
    data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
        .with_special_case('boo:and:foo')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "[:]")',
            'split(a, "[o:]")',
            'split(a, "[^:]")',
            'split(a, "[^o]")',
            'split(a, "[o]{1,2}")',
            'split(a, "[bf]")',
            'split(a, "[o]")'))

@pytest.mark.parametrize('data_gen,delim', [(mk_str_gen('([ABC]{0,3}_?){0,7}'), '_'),
                                            (mk_str_gen('([MNP_]{0,3}\\.?){0,5}'), '.'),
                                            (mk_str_gen('([123]{0,3}\\^?){0,5}'), '^')], ids=idfn)
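The limit tests above mirror Spark's `split(str, regex, limit)` semantics: a positive limit caps the array length and leaves the remainder of the string in the last element, while a negative limit splits as many times as possible (limits of 0 and 1 currently fall back to the CPU, per issue #4720). A minimal spark-shell sketch of the expected behavior, reusing the `boo:and:foo` value from the tests (session setup assumed):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Positive limit: at most `limit` elements; the remainder stays in the last one.
spark.sql("SELECT split('boo:and:foo', '[:]', 2)").show(false)   // [boo, and:foo]

// Negative limit: the pattern is applied as many times as possible.
spark.sql("SELECT split('boo:and:foo', '[:]', -1)").show(false)  // [boo, and, foo]
```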
@@ -15,7 +15,7 @@
 */
package com.nvidia.spark.rapids.shims.v2

-import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta}
+import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
          // use GpuStringReplace
        } else {
          try {
-           pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+           pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
          } catch {
            case e: RegexUnsupportedException =>
              willNotWorkOnGpu(e.getMessage)
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
          // use GpuStringReplace
        } else {
          try {
-           pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+           pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
          } catch {
            case e: RegexUnsupportedException =>
              willNotWorkOnGpu(e.getMessage)
@@ -15,7 +15,7 @@
 */
package com.nvidia.spark.rapids.shims.v2

-import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}
+import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
          // use GpuStringReplace
        } else {
          try {
-           pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+           pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
          } catch {
            case e: RegexUnsupportedException =>
              willNotWorkOnGpu(e.getMessage)
@@ -15,7 +15,7 @@
 */
package com.nvidia.spark.rapids.shims.v2

-import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}
+import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
@@ -40,7 +40,7 @@ class GpuRegExpReplaceMeta(
          // use GpuStringReplace
        } else {
          try {
-           pattern = Some(new CudfRegexTranspiler(replace = true).transpile(s.toString))
+           pattern = Some(new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString))
          } catch {
            case e: RegexUnsupportedException =>
              willNotWorkOnGpu(e.getMessage)
@@ -433,14 +433,19 @@ object RegexParser {
  }
}

sealed trait RegexMode
object RegexFindMode extends RegexMode
object RegexReplaceMode extends RegexMode
object RegexSplitMode extends RegexMode

/**
 * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception
 * if this is not possible.
 *
- * @param replace True if performing a replacement (regexp_replace), false
- *                if matching only (rlike)
+ * @param mode the mode of regex processing: find (rlike), replace
+ *             (regexp_replace), or split (string_split)
 */
-class CudfRegexTranspiler(replace: Boolean) {
+class CudfRegexTranspiler(mode: RegexMode) {

  // cuDF throws a "nothing to repeat" exception for many of the edge cases that are
  // rejected by the transpiler
@@ -472,6 +477,8 @@ class CudfRegexTranspiler(replace: Boolean) {
      case '$' =>
        // see https://github.com/NVIDIA/spark-rapids/issues/4533
        throw new RegexUnsupportedException("line anchor $ is not supported")
      case '^' if mode == RegexSplitMode =>
        throw new RegexUnsupportedException("line anchor ^ is not supported in split mode")
      case _ =>
        regex
    }
@@ -506,8 +513,14 @@
      case 's' | 'S' =>
        // see https://github.com/NVIDIA/spark-rapids/issues/4528
        throw new RegexUnsupportedException("whitespace classes are not supported")
      case 'A' if mode == RegexSplitMode =>
        throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
      case 'Z' if mode == RegexSplitMode =>
        throw new RegexUnsupportedException("string anchor \\Z is not supported in split mode")
      case 'z' if mode == RegexSplitMode =>
        throw new RegexUnsupportedException("string anchor \\z is not supported in split mode")
      case 'z' =>
-       if (replace) {
+       if (mode == RegexReplaceMode) {
          // see https://github.com/NVIDIA/spark-rapids/issues/4425
          throw new RegexUnsupportedException(
            "string anchor \\z is not supported in replace mode")
@@ -607,7 +620,7 @@ class CudfRegexTranspiler(replace: Boolean) {
        RegexSequence(parts.map(rewrite))

      case RegexRepetition(base, quantifier) => (base, quantifier) match {
-       case (_, SimpleQuantifier(ch)) if replace && "?*".contains(ch) =>
+       case (_, SimpleQuantifier(ch)) if mode == RegexReplaceMode && "?*".contains(ch) =>
          // example: pattern " ?", input "] b[", replace with "X":
          // java: X]XXbX[X
          // cuDF: XXXX] b[
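A quick sketch of how the new `RegexMode` parameter drives transpilation, using only names introduced in the diff above (this assumes the spark-rapids classes on the classpath, and that `transpile` returns the rewritten pattern as a `String`, as the `pattern = Some(...)` call sites above suggest):

```scala
import com.nvidia.spark.rapids.{CudfRegexTranspiler, RegexSplitMode, RegexUnsupportedException}

// Split mode accepts patterns like the character classes used in the tests...
val transpiler = new CudfRegexTranspiler(RegexSplitMode)
val cudfPattern = transpiler.transpile("[o]{1,2}")

// ...but rejects line and string anchors, per the rewrite rules above.
try {
  transpiler.transpile("^foo")
} catch {
  case e: RegexUnsupportedException =>
    println(e.getMessage)  // "line anchor ^ is not supported in split mode"
}
```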
@@ -121,6 +121,10 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpression
    // predicate boolean array results in the two T values mapping to
    // indices 0 and 1, respectively.

    // [F, null, T, F, T]
    // [0, 0, 0, 1, 1]
    // [0, 1]

    val prefixSumExclusive = withResource(boolToInt(predicate)) { boolsAsInts =>
      boolsAsInts.scan(
        ScanAggregation.sum(),
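To make the comment above concrete, here is a standalone sketch of the exclusive prefix-sum trick in plain Scala collections; the real code performs the same computation on cuDF columns via `scan`:

```scala
// Plain-Scala illustration of building a gather map from a boolean predicate
// via an exclusive prefix sum.
object PrefixSumGatherSketch extends App {
  // [F, null, T, F, T]
  val predicate = Seq(Some(false), None, Some(true), Some(false), Some(true))

  // nulls and false both map to 0, true maps to 1: [0, 0, 1, 0, 1]
  val boolsAsInts = predicate.map { case Some(true) => 1; case _ => 0 }

  // Exclusive scan: running sum before each element, final total dropped:
  // [0, 0, 0, 1, 1]
  val prefixSumExclusive = boolsAsInts.scanLeft(0)(_ + _).init

  // The prefix-sum value at each T position is that row's output index,
  // so the two T values map to indices 0 and 1.
  val gatherMap = predicate.zip(prefixSumExclusive).collect {
    case (Some(true), idx) => idx
  }
  println(gatherMap)  // List(0, 1)
}
```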