From d3dc49679004e56a515ad1d0e5ab27a113ae1b56 Mon Sep 17 00:00:00 2001
From: "Robert (Bobby) Evans"
Date: Mon, 5 Aug 2024 09:19:59 -0500
Subject: [PATCH] Add work around for string split with empty input. (#11292)

Signed-off-by: Robert (Bobby) Evans
---
 .../src/main/python/array_test.py             |  9 ++++++-
 integration_tests/src/main/python/data_gen.py | 21 +++++++++++++--
 integration_tests/src/main/python/map_test.py | 10 +++++++
 .../src/main/python/string_test.py            | 12 +++++++
 .../spark/sql/rapids/stringFunctions.scala    | 26 +++++++++++++++----
 5 files changed, 70 insertions(+), 8 deletions(-)

diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py
index a463d1af453..906f74567dc 100644
--- a/integration_tests/src/main/python/array_test.py
+++ b/integration_tests/src/main/python/array_test.py
@@ -186,6 +186,13 @@ def test_make_array(data_gen):
                 'array(b, a, null, {}, {})'.format(s1, s2),
                 'array(array(b, a, null, {}, {}), array(a), array(null))'.format(s1, s2)))
 
+@pytest.mark.parametrize('empty_type', all_empty_string_types)
+def test_make_array_empty_input(empty_type):
+    data_gen = mk_empty_str_gen(empty_type)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : binary_op_df(spark, data_gen).selectExpr(
+            'array(a)',
+            'array(a, b)'))
 
 @pytest.mark.parametrize('data_gen', single_level_array_gens, ids=idfn)
 def test_orderby_array_unique(data_gen):
@@ -795,4 +802,4 @@ def test_map_from_arrays_length_exception():
         lambda spark: gen_df(spark, gen).selectExpr(
             'map_from_arrays(array(1), a)').collect(),
         conf={'spark.sql.mapKeyDedupPolicy':'EXCEPTION'},
-        error_message = "The key array and value array of MapData must have the same length")
\ No newline at end of file
+        error_message = "The key array and value array of MapData must have the same length")
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index fb26edfc20e..c17142bded5 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -15,6 +15,7 @@
 import copy
 from datetime import date, datetime, timedelta, timezone
 from decimal import *
+from enum import Enum
 import math
 from pyspark.context import SparkContext
 from pyspark.sql import Row
@@ -744,8 +745,8 @@ def contains_ts(self):
 
 class NullGen(DataGen):
     """Generate NullType values"""
-    def __init__(self):
-        super().__init__(NullType(), nullable=True)
+    def __init__(self, dt = NullType()):
+        super().__init__(dt, nullable=True)
 
     def start(self, rand):
         def make_null():
@@ -1044,6 +1045,22 @@ def gen_scalars_for_sql(data_gen, count, seed=None, force_no_nulls=False):
     spark_type = data_gen.data_type
     return (_convert_to_sql(spark_type, src.gen(force_no_nulls=force_no_nulls)) for i in range(0, count))
 
+class EmptyStringType(Enum):
+    ALL_NULL = 1
+    ALL_EMPTY = 2
+    MIXED = 3
+
+all_empty_string_types = EmptyStringType.__members__.values()
+
+empty_string_gens_map = {
+    EmptyStringType.ALL_NULL : lambda: NullGen(StringType()),
+    EmptyStringType.ALL_EMPTY : lambda: StringGen("", nullable=False),
+    EmptyStringType.MIXED : lambda: StringGen("", nullable=True)
+}
+
+def mk_empty_str_gen(empty_type):
+    return empty_string_gens_map[empty_type]()
+
 byte_gen = ByteGen()
 short_gen = ShortGen()
 int_gen = IntegerGen()
diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py
index 72e92280d43..fa647761b62 100644
--- a/integration_tests/src/main/python/map_test.py
+++ b/integration_tests/src/main/python/map_test.py
@@ -440,6 +440,16 @@ def test_str_to_map_expr_with_all_regex_delimiters():
         ), conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})
 
+@pytest.mark.parametrize('empty_type', all_empty_string_types)
+def test_str_to_map_input_all_empty(empty_type):
+    data_gen = mk_empty_str_gen(empty_type)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'str_to_map(a) as m0',
+            'str_to_map(a, ",") as m1',
+            'str_to_map(a, ",", ":") as m2'
+        ), conf={'spark.sql.mapKeyDedupPolicy': 'LAST_WIN'})
+
 
 @pytest.mark.skipif(not is_before_spark_330(),
                     reason="Only in Spark 3.1.1+ (< 3.3.0) + ANSI mode, map key throws on no such element")
 @pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 7b17b9af700..4ae4a827aa0 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -30,6 +30,18 @@
 def mk_str_gen(pattern):
     return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
 
+@pytest.mark.parametrize('empty_type', all_empty_string_types)
+@pytest.mark.parametrize('num_splits', ['-1', '0', '1', '2'])
+def test_split_input_all_empty(empty_type, num_splits):
+    data_gen = mk_empty_str_gen(empty_type)
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'split(a, "AB", ' + num_splits + ')',
+            'split(a, "C", ' + num_splits + ')',
+            'split(a, ">>", ' + num_splits + ')',
+            'split(a, "_", ' + num_splits + ')'),
+        conf=_regexp_conf)
+
 def test_split_no_limit():
     data_gen = mk_str_gen('([ABC]{0,3}_?){0,7}')
     assert_gpu_and_cpu_are_equal_collect(
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 4e243c79736..874f38a21c4 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -32,6 +32,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
 import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.unsafe.types.UTF8String
@@ -1791,6 +1792,22 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
 
   override def doColumnar(str: GpuColumnVector, regex: GpuScalar,
       limit: GpuScalar): ColumnVector = {
+    // TODO when https://github.com/rapidsai/cudf/issues/16453 is fixed remove this workaround
+    withResource(str.getBase.getData) { data =>
+      if (data == null || data.getLength <= 0) {
+        // An empty data buffer for a string column means that every input is either null
+        // or an empty string. CUDF returns null for all of them, but only the null inputs
+        // should produce null; the rest should produce an array holding an empty string.
+        withResource(GpuScalar.from(null, dataType)) { nullArray =>
+          withResource(GpuScalar.from(
+            new GenericArrayData(Array(UTF8String.EMPTY_UTF8)), dataType)) { emptyStringArray =>
+            withResource(str.getBase.isNull) { retNull =>
+              return retNull.ifElse(nullArray, emptyStringArray)
+            }
+          }
+        }
+      }
+    }
     limit.getValue.asInstanceOf[Int] match {
       case 0 =>
         // Same as splitting as many times as possible
@@ -1802,11 +1819,10 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
       case 1 =>
         // Short circuit GPU and just return a list containing the original input string
         withResource(str.getBase.isNull) { isNull =>
-          withResource(GpuScalar.from(null, DataTypes.createArrayType(DataTypes.StringType))) {
-            nullStringList =>
-              withResource(ColumnVector.makeList(str.getBase)) { list =>
-                isNull.ifElse(nullStringList, list)
-              }
+          withResource(GpuScalar.from(null, dataType)) { nullStringList =>
+            withResource(ColumnVector.makeList(str.getBase)) { list =>
+              isNull.ifElse(nullStringList, list)
+            }
           }
         }
       case n =>
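
For reference, the CPU semantics the workaround above preserves: Spark's split returns
a one-element array holding the empty string for an empty (non-null) input, and null
only for a null input. Below is a minimal PySpark sketch of that behavior (an
illustration only, not part of the patch; it assumes a local Spark session):

    from pyspark.sql import SparkSession

    # One empty-string input and one null input.
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([("",), (None,)], "a string")

    rows = df.selectExpr('split(a, "AB") AS s').collect()
    assert rows[0].s == ['']   # empty (non-null) input -> array with one empty string
    assert rows[1].s is None   # null input -> null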