Add workaround for string split with empty input. #11292

Merged · 2 commits · Aug 5, 2024
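Background for the fix: on the CPU, Spark's `split` returns null for a null input, but returns a one-element array holding the empty string for an empty-string input. Per the workaround comment in the Scala diff below, cuDF currently returns null for every row when the whole column's character data is empty (rapidsai/cudf#16453). A minimal PySpark sketch of the expected CPU semantics, illustrative only and not part of this PR:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Expected CPU semantics that the GPU must match:
row = spark.sql(
    "SELECT split(CAST(NULL AS STRING), ',') AS null_in, "
    "split('', ',') AS empty_in").collect()[0]

assert row.null_in is None      # null input  -> null result
assert row.empty_in == ['']     # empty input -> one-element array ['']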
integration_tests/src/main/python/array_test.py (8 additions, 1 deletion)

@@ -186,6 +186,13 @@ def test_make_array(data_gen):
            'array(b, a, null, {}, {})'.format(s1, s2),
            'array(array(b, a, null, {}, {}), array(a), array(null))'.format(s1, s2)))

@pytest.mark.parametrize('empty_type', all_empty_string_types)
def test_make_array_empty_input(empty_type):
    data_gen = mk_empty_str_gen(empty_type)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : binary_op_df(spark, data_gen).selectExpr(
            'array(a)',
            'array(a, b)'))

@pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
def test_orderby_array_unique(data_gen):
@@ -795,4 +802,4 @@ def test_map_from_arrays_length_exception():
        lambda spark: gen_df(spark, gen).selectExpr(
            'map_from_arrays(array(1), a)').collect(),
        conf={'spark.sql.mapKeyDedupPolicy':'EXCEPTION'},
-        error_message = "The key array and value array of MapData must have the same length")
\ No newline at end of file
+        error_message = "The key array and value array of MapData must have the same length")
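The new test_make_array_empty_input above covers the ALL_NULL case, which appears to be why NullGen in data_gen.py (next file) gains an optional data type. The rationale below is inferred from the diff, not stated in the PR:

# Inferred rationale (assumption): a typed all-NULL generator keeps the
# column a string column, so array(a) resolves to array<string> on both
# CPU and GPU rather than an untyped array of nulls.
gen = NullGen(StringType())   # enabled by the data_gen.py change below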
integration_tests/src/main/python/data_gen.py (19 additions, 2 deletions)

@@ -15,6 +15,7 @@
import copy
from datetime import date, datetime, timedelta, timezone
from decimal import *
from enum import Enum
import math
from pyspark.context import SparkContext
from pyspark.sql import Row
@@ -744,8 +745,8 @@ def contains_ts(self):

class NullGen(DataGen):
"""Generate NullType values"""
def __init__(self):
super().__init__(NullType(), nullable=True)
def __init__(self, dt = NullType()):
super().__init__(dt, nullable=True)

    def start(self, rand):
        def make_null():
@@ -1044,6 +1045,22 @@ def gen_scalars_for_sql(data_gen, count, seed=None, force_no_nulls=False):
    spark_type = data_gen.data_type
    return (_convert_to_sql(spark_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count))

class EmptyStringType(Enum):
    ALL_NULL = 1
    ALL_EMPTY = 2
    MIXED = 3

all_empty_string_types = EmptyStringType.__members__.values()

empty_string_gens_map = {
    EmptyStringType.ALL_NULL : lambda: NullGen(StringType()),
    EmptyStringType.ALL_EMPTY : lambda: StringGen("", nullable=False),
    EmptyStringType.MIXED : lambda: StringGen("", nullable=True)
}

def mk_empty_str_gen(empty_type):
    return empty_string_gens_map[empty_type]()

byte_gen = ByteGen()
short_gen = ShortGen()
int_gen = IntegerGen()
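For reference, a sketch of how the new helpers above are meant to be used, based only on the definitions in this diff (`unary_op_df` and `assert_gpu_and_cpu_are_equal_collect` come from the existing test harness):

# Each variant yields a StringType column with no real character data:
all_null  = mk_empty_str_gen(EmptyStringType.ALL_NULL)   # every row NULL
all_empty = mk_empty_str_gen(EmptyStringType.ALL_EMPTY)  # every row '', never NULL
mixed     = mk_empty_str_gen(EmptyStringType.MIXED)      # rows are '' or NULL

# Plugged into a test exactly like the new tests in this PR:
# assert_gpu_and_cpu_are_equal_collect(
#     lambda spark: unary_op_df(spark, mixed).selectExpr('split(a, ",")'))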
integration_tests/src/main/python/map_test.py (10 additions)

@@ -440,6 +440,16 @@ def test_str_to_map_expr_with_all_regex_delimiters():
        ), conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})


@pytest.mark.parametrize('empty_type', all_empty_string_types)
def test_str_to_map_input_all_empty(empty_type):
    data_gen = mk_empty_str_gen(empty_type)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'str_to_map(a) as m0',
            'str_to_map(a, ",") as m1',
            'str_to_map(a, ",", ":") as m2'
        ), conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})

@pytest.mark.skipif(not is_before_spark_330(),
reason="Only in Spark 3.1.1+ (< 3.3.0) + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
integration_tests/src/main/python/string_test.py (12 additions)

@@ -30,6 +30,18 @@
def mk_str_gen(pattern):
    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

@pytest.mark.parametrize('empty_type', all_empty_string_types)
@pytest.mark.parametrize('num_splits', ['-1', '0', '1', '2'])
def test_split_input_all_empty(empty_type, num_splits):
    data_gen = mk_empty_str_gen(empty_type)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'split(a, "AB", ' + num_splits + ')',
            'split(a, "C", ' + num_splits + ')',
            'split(a, ">>", ' + num_splits + ')',
            'split(a, "_", ' + num_splits + ')'),
        conf=_regexp_conf)

def test_split_no_limit():
    data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
    assert_gpu_and_cpu_are_equal_collect(
GpuStringSplit (Scala)

@@ -32,6 +32,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -1791,6 +1792,22 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,

  override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
      limit: GpuScalar): ColumnVector = {
    // TODO: remove this workaround when https://github.com/rapidsai/cudf/issues/16453 is fixed
    withResource(str.getBase.getData) { data =>
      if (data == null || data.getLength <= 0) {
        // An empty data buffer for a string column means that every input is either null or an
        // empty string. cuDF will return null for all of them, but only a null input should
        // produce a null result; an empty string should produce a single-element array [""].
        withResource(GpuScalar.from(null, dataType)) { nullArray =>
          withResource(GpuScalar.from(
              new GenericArrayData(Array(UTF8String.EMPTY_UTF8)), dataType)) { emptyStringArray =>
            withResource(str.getBase.isNull) { retNull =>
              return retNull.ifElse(nullArray, emptyStringArray)
            }
          }
        }
      }
    }
    limit.getValue.asInstanceOf[Int] match {
      case 0 =>
        // Same as splitting as many times as possible
@@ -1802,11 +1819,10 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
      case 1 =>
        // Short circuit GPU and just return a list containing the original input string
        withResource(str.getBase.isNull) { isNull =>
-          withResource(GpuScalar.from(null, DataTypes.createArrayType(DataTypes.StringType))) {
-            nullStringList =>
-              withResource(ColumnVector.makeList(str.getBase)) { list =>
-                isNull.ifElse(nullStringList, list)
-              }
+          withResource(GpuScalar.from(null, dataType)) { nullStringList =>
+            withResource(ColumnVector.makeList(str.getBase)) { list =>
+              isNull.ifElse(nullStringList, list)
+            }
          }
        }
      case n =>
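To summarize the control flow of the workaround above, here is a small Python model of the per-row semantics it restores. This is a hypothetical stand-in: the real code operates on cuDF column buffers, not Python lists.

import re

def gpu_string_split_model(rows, pattern):
    """Per-row semantics of GpuStringSplit.doColumnar with the workaround.

    rows stands in for a string column: a list of Optional[str].
    """
    # Workaround for rapidsai/cudf#16453: an empty character buffer means
    # every row is NULL or ''. cuDF would return NULL for all of them, but
    # only NULL inputs may become NULL; '' must become [''].
    if all(r is None or r == '' for r in rows):
        return [None if r is None else [''] for r in rows]
    # Normal path (simplified; the real code also dispatches on the limit).
    return [None if r is None else re.split(pattern, r) for r in rows]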