Add workaround for string split with empty input. (#11292)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Aug 5, 2024
1 parent 20bff54 commit d3dc496
Showing 5 changed files with 70 additions and 8 deletions.
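
For context on the behavior this commit restores: Spark's CPU split returns null only when the input string is null; an empty (non-null) input yields a one-element array containing the empty string. A minimal PySpark sketch of that expectation (illustrative, not part of the commit; assumes a local Spark session):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
rows = spark.sql(
    "SELECT split('', 'AB') AS empty_in, split(CAST(NULL AS STRING), 'AB') AS null_in"
).collect()
# Expected on the CPU: [Row(empty_in=[''], null_in=None)]
# Without the workaround, cuDF (rapidsai/cudf#16453) returned null for both columns.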
9 changes: 8 additions & 1 deletion integration_tests/src/main/python/array_test.py
@@ -186,6 +186,13 @@ def test_make_array(data_gen):
'array(b, a, null, {}, {})'.format(s1, s2),
'array(array(b, a, null, {}, {}), array(a), array(null))'.format(s1, s2)))

@pytest.mark.parametrize('empty_type', all_empty_string_types)
def test_make_array_empty_input(empty_type):
data_gen = mk_empty_str_gen(empty_type)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).selectExpr(
'array(a)',
'array(a, b)'))

@pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
def test_orderby_array_unique(data_gen):
@@ -795,4 +802,4 @@ def test_map_from_arrays_length_exception():
lambda spark: gen_df(spark, gen).selectExpr(
'map_from_arrays(array(1), a)').collect(),
conf={'spark.sql.mapKeyDedupPolicy':'EXCEPTION'},
-        error_message = "The key array and value array of MapData must have the same length")
\ No newline at end of file
+        error_message = "The key array and value array of MapData must have the same length")
21 changes: 19 additions & 2 deletions integration_tests/src/main/python/data_gen.py
@@ -15,6 +15,7 @@
import copy
from datetime import date, datetime, timedelta, timezone
from decimal import *
from enum import Enum
import math
from pyspark.context import SparkContext
from pyspark.sql import Row
@@ -744,8 +745,8 @@ def contains_ts(self):

class NullGen(DataGen):
"""Generate NullType values"""
-    def __init__(self):
-        super().__init__(NullType(), nullable=True)
+    def __init__(self, dt = NullType()):
+        super().__init__(dt, nullable=True)

def start(self, rand):
def make_null():
@@ -1044,6 +1045,22 @@ def gen_scalars_for_sql(data_gen, count, seed=None, force_no_nulls=False):
spark_type = data_gen.data_type
return (_convert_to_sql(spark_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count))

class EmptyStringType(Enum):
ALL_NULL = 1
ALL_EMPTY = 2
MIXED = 3

all_empty_string_types = EmptyStringType.__members__.values()

empty_string_gens_map = {
EmptyStringType.ALL_NULL : lambda: NullGen(StringType()),
EmptyStringType.ALL_EMPTY : lambda: StringGen("", nullable=False),
EmptyStringType.MIXED : lambda: StringGen("", nullable=True)
}

def mk_empty_str_gen(empty_type):
return empty_string_gens_map[empty_type]()
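
As a rough sketch of what each case yields per row (assuming the NullGen/StringGen semantics defined in this file; the example values are illustrative, not actual generator output):

# EmptyStringType.ALL_NULL  -> a StringType column of nulls only, e.g. [None, None, None]
# EmptyStringType.ALL_EMPTY -> empty strings and never null,      e.g. ['', '', '']
# EmptyStringType.MIXED     -> empty strings mixed with nulls,    e.g. ['', None, '']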

byte_gen = ByteGen()
short_gen = ShortGen()
int_gen = IntegerGen()
10 changes: 10 additions & 0 deletions integration_tests/src/main/python/map_test.py
@@ -440,6 +440,16 @@ def test_str_to_map_expr_with_all_regex_delimiters():
), conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})


@pytest.mark.parametrize('empty_type', all_empty_string_types)
def test_str_to_map_input_all_empty(empty_type):
data_gen = mk_empty_str_gen(empty_type)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'str_to_map(a) as m0',
'str_to_map(a, ",") as m1',
'str_to_map(a, ",", ":") as m2'
), conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})
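
For reference, a hedged sketch of what the CPU side is expected to produce here, based on Spark's str_to_map semantics (an empty string splits into a single empty key with no value; a null input gives a null map; worth verifying against an actual session):

# Illustrative expectations with the default ',' and ':' delimiters:
#   str_to_map('')   -> {'': None}   (one empty key, null value)
#   str_to_map(NULL) -> None         (null map)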

@pytest.mark.skipif(not is_before_spark_330(),
reason="Only in Spark 3.1.1+ (< 3.3.0) + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
12 changes: 12 additions & 0 deletions integration_tests/src/main/python/string_test.py
@@ -30,6 +30,18 @@
def mk_str_gen(pattern):
return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

@pytest.mark.parametrize('empty_type', all_empty_string_types)
@pytest.mark.parametrize('num_splits', ['-1', '0', '1', '2'])
def test_split_input_all_empty(empty_type, num_splits):
data_gen = mk_empty_str_gen(empty_type)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'split(a, "AB", ' + num_splits + ')',
'split(a, "C", ' + num_splits + ')',
'split(a, ">>", ' + num_splits + ')',
'split(a, "_", ' + num_splits + ')'),
conf=_regexp_conf)
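
Note that the limit should not matter for these rows: under Java/Spark split semantics an empty (non-null) input produces a one-element result for any limit, while a null input always yields null. A hedged sketch of the expected values:

# Illustrative expectations for the rows generated above:
#   split('', 'AB', -1) -> ['']
#   split('', 'AB', 0)  -> ['']   # trailing-empty trimming keeps the lone element
#   split('', 'AB', 1)  -> ['']
#   split('', 'AB', 2)  -> ['']
#   split(NULL, 'AB', n) -> None for every n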

def test_split_no_limit():
data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
assert_gpu_and_cpu_are_equal_collect(
@@ -32,6 +32,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -1791,6 +1792,22 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,

override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
limit: GpuScalar): ColumnVector = {
// TODO when https://github.com/rapidsai/cudf/issues/16453 is fixed remove this workaround
withResource(str.getBase.getData) { data =>
if (data == null || data.getLength <= 0) {
// Empty string data means that every input is either null or an empty string.
// cuDF returns null for all of these rows, but null is only correct for a null
// input; an empty-string input should instead produce [""] (one empty string).
withResource(GpuScalar.from(null, dataType)) { nullArray =>
withResource(GpuScalar.from(
new GenericArrayData(Array(UTF8String.EMPTY_UTF8)), dataType)) { emptyStringArray =>
withResource(str.getBase.isNull) { retNull =>
return retNull.ifElse(nullArray, emptyStringArray)
}
}
}
}
}
limit.getValue.asInstanceOf[Int] match {
case 0 =>
// Same as splitting as many times as possible
@@ -1802,11 +1819,10 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
case 1 =>
// Short circuit GPU and just return a list containing the original input string
withResource(str.getBase.isNull) { isNull =>
-        withResource(GpuScalar.from(null, DataTypes.createArrayType(DataTypes.StringType))) {
-          nullStringList =>
-            withResource(ColumnVector.makeList(str.getBase)) { list =>
-              isNull.ifElse(nullStringList, list)
-            }
+        withResource(GpuScalar.from(null, dataType)) { nullStringList =>
+          withResource(ColumnVector.makeList(str.getBase)) { list =>
+            isNull.ifElse(nullStringList, list)
+          }
}
}
case n =>
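
The workaround above amounts to a row-wise fixup: once the string data buffer is known to be empty, every row is either null or '', so the result can be built from the null mask alone. A rough Python model of that logic (the function name and list representation are illustrative, not the plugin's API):

def split_all_empty_workaround(rows):
    # Only applies when every input is null or '' (empty string data buffer).
    # cuDF (rapidsai/cudf#16453) would return null for every row here, but
    # Spark expects null only for null rows and [''] for empty-string rows.
    return [None if r is None else [''] for r in rows]

assert split_all_empty_workaround([None, '', '']) == [None, [''], ['']]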
