From 7a6c423b4d843643b9beb4d4f09d28f58136da07 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Fri, 31 Dec 2021 13:16:22 +0800
Subject: [PATCH 1/2] add recursive type checking for casting array to string

add Map to unsupported element types in IT
replace xfail by fallback when casting unsupported types to string in IT

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 docs/supported_ops.md                        |  8 ++---
 .../src/main/python/cast_test.py             | 34 +++++++++++++------
 .../com/nvidia/spark/rapids/GpuCast.scala    |  3 ++
 .../com/nvidia/spark/rapids/TypeChecks.scala |  7 ++--
 4 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 80506c0e161..886087b605d 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16771,12 +16771,12 @@ and the accelerator produces the same result.
-S
+PS
the array's children must also support being cast to string
-PS
The array's child type must also support being cast to the desired child type;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+PS
The array's children must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
@@ -17175,12 +17175,12 @@ and the accelerator produces the same result.
-S
+PS
the array's children must also support being cast to string
-PS
The array's child type must also support being cast to the desired child type;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+PS
The array's children must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py
index b4c0194174a..62192aba283 100644
--- a/integration_tests/src/main/python/cast_test.py
+++ b/integration_tests/src/main/python/cast_test.py
@@ -174,7 +174,7 @@ def test_cast_long_to_decimal_overflow():
 # casting these types to string is not exact match, marked as xfail when testing
 not_matched_gens_for_cast_to_string = [float_gen, double_gen, decimal_gen_neg_scale]
 # casting these types to string is not supported, marked as xfail when testing
-not_support_gens_for_cast_to_string = decimal_128_gens
+not_support_gens_for_cast_to_string = decimal_128_gens + [MapGen(ByteGen(False), ByteGen())]
 
 single_level_array_gens_for_cast_to_string = [ArrayGen(sub_gen) for sub_gen in basic_gens_for_cast_to_string]
 nested_array_gens_for_cast_to_string = [
@@ -188,13 +188,23 @@ def test_cast_long_to_decimal_overflow():
 
 def _assert_cast_to_string_equal (data_gen, conf):
     """
-    helper function for casting to string
+    helper function for casting a supported type to string
     """
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
         conf
     )
 
+def _assert_cast_to_string_fallback (data_gen, conf):
+    """
+    helper function for casting an unsupported type to string
+    """
+    assert_gpu_fallback_collect(
+        lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
+        "Cast",
+        conf
+    )
+
 @pytest.mark.parametrize('data_gen', all_gens_for_cast_to_string, ids=idfn)
 @pytest.mark.parametrize('legacy', ['true', 'false'])
 def test_cast_array_to_string(data_gen, legacy):
@@ -219,12 +229,13 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):
 
 
 @pytest.mark.parametrize('data_gen', [ArrayGen(sub) for sub in not_support_gens_for_cast_to_string], ids=idfn)
 @pytest.mark.parametrize('legacy', ['true', 'false'])
-@pytest.mark.xfail(reason='casting this type to string is not supported')
-def test_cast_array_with_unsupported_element_to_string(data_gen, legacy):
-    _assert_cast_to_string_equal(
+@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
+def test_cast_array_with_unsupported_element_to_string_fallback(data_gen, legacy):
+    _assert_cast_to_string_fallback(
         data_gen,
         {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
-         "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
+         "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
+         "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
     )
 
@@ -285,11 +296,12 @@ def test_cast_struct_with_unmatched_element_to_string(data_gen, legacy):
 
 
 @pytest.mark.parametrize('data_gen', [StructGen([["first", element_gen]]) for element_gen in not_support_gens_for_cast_to_string], ids=idfn)
 @pytest.mark.parametrize('legacy', ['true', 'false'])
-@pytest.mark.xfail(reason='casting this type to string is not supported')
-def test_cast_struct_with_unsupported_element_to_string(data_gen, legacy):
-    _assert_cast_to_string_equal(
+@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
+def test_cast_struct_with_unsupported_element_to_string_fallback(data_gen, legacy):
+    _assert_cast_to_string_fallback(
         data_gen,
-        {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
-         "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
+        {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
+         "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
+         "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
     )
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 41a8cf0fc01..68c557fbf8d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -124,6 +124,9 @@ final class CastExprMeta[INPUT <: CastBase](
           case (fromChild, toChild) =>
             recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1)
         }
+      case (ArrayType(element_type, _), StringType) =>
+        recursiveTagExprForGpuCheck(element_type, StringType, depth + 1)
+
       case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) =>
         recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1)
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 5687dfcc6fd..5d294c590b0 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1270,10 +1270,11 @@ class CastChecks extends ExprChecks {
   val calendarChecks: TypeSig = none
   val sparkCalendarSig: TypeSig = CALENDAR + STRING
 
-  val arrayChecks: TypeSig = STRING + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
+  val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's children must also support " +
+      "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
       ARRAY + BINARY + STRUCT + MAP) +
-      psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " +
-        "the desired child type")
+      psNote(TypeEnum.ARRAY, "The array's children must also support being cast to the " +
+        "desired child type(s)")
 
   val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)
 

From 86e8b30a8668cec37a4ee21b6d992feca9330cb3 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Tue, 4 Jan 2022 09:30:56 +0800
Subject: [PATCH 2/2] fix some minor issues

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 docs/supported_ops.md                                     | 8 ++++----
 .../src/main/scala/com/nvidia/spark/rapids/GpuCast.scala | 4 ++--
 .../main/scala/com/nvidia/spark/rapids/TypeChecks.scala  | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 886087b605d..f4685bc3aef 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16771,12 +16771,12 @@ and the accelerator produces the same result.
-PS
the array's children must also support being cast to string
+PS
the array's child type must also support being cast to string
-PS
The array's children must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+PS
The array's child type must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
@@ -17175,12 +17175,12 @@ and the accelerator produces the same result.
-PS
the array's children must also support being cast to string
+PS
the array's child type must also support being cast to string
-PS
The array's children must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+PS
The array's child type must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 68c557fbf8d..9b8e251a52e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -124,8 +124,8 @@ final class CastExprMeta[INPUT <: CastBase](
           case (fromChild, toChild) =>
             recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1)
         }
-      case (ArrayType(element_type, _), StringType) =>
-        recursiveTagExprForGpuCheck(element_type, StringType, depth + 1)
+      case (ArrayType(elementType, _), StringType) =>
+        recursiveTagExprForGpuCheck(elementType, StringType, depth + 1)
 
       case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) =>
         recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 5d294c590b0..61f31d52277 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1270,10 +1270,10 @@ class CastChecks extends ExprChecks {
   val calendarChecks: TypeSig = none
   val sparkCalendarSig: TypeSig = CALENDAR + STRING
 
-  val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's children must also support " +
+  val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " +
       "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
       ARRAY + BINARY + STRUCT + MAP) +
-      psNote(TypeEnum.ARRAY, "The array's children must also support being cast to the " +
+      psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " +
         "desired child type(s)")
 
   val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)
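For reference, below is a minimal, self-contained sketch of the recursion that [PATCH 1/2] adds to GpuCast.scala, written with assumed names rather than the plugin's actual API: the real recursiveTagExprForGpuCheck tags an unsupported cast so the plan falls back to the CPU instead of returning a Boolean, but the shape of the check is the same, i.e. when casting ARRAY to STRING, recurse into the element type before accepting the cast.

import org.apache.spark.sql.types._

object ArrayToStringCheckSketch {
  // Hypothetical leaf rule standing in for the plugin's real per-type cast checks.
  private def leafCastableToString(dt: DataType): Boolean = dt match {
    case _: MapType           => false // mirrors MapGen joining not_support_gens_for_cast_to_string
    case CalendarIntervalType => false
    case _                    => true
  }

  // Analogue of the added `case (ArrayType(elementType, _), StringType)` branch:
  // recurse through nested arrays (and struct fields) before accepting the cast.
  def castableToString(dt: DataType): Boolean = dt match {
    case ArrayType(elementType, _) => castableToString(elementType)
    case StructType(fields)        => fields.forall(f => castableToString(f.dataType))
    case other                     => leafCastableToString(other)
  }

  def main(args: Array[String]): Unit = {
    // ARRAY<MAP<BYTE, BYTE>> is rejected, so the corresponding Cast would stay on the CPU;
    // ARRAY<ARRAY<INT>> is accepted.
    println(castableToString(ArrayType(MapType(ByteType, ByteType)))) // false
    println(castableToString(ArrayType(ArrayType(IntegerType))))      // true
  }
}

Under this scheme an array with MAP elements is rejected at any nesting depth, which is the same case the new *_fallback integration tests exercise with MapGen elements.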