diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 80506c0e161..f4685bc3aef 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -16771,12 +16771,12 @@ and the accelerator produces the same result. -S +PS
the array's child type must also support being cast to string
-PS
The array's child type must also support being cast to the desired child type;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+PS
The array's child type must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
@@ -17175,12 +17175,12 @@ and the accelerator produces the same result. -S +PS
the array's child type must also support being cast to string
-PS
The array's child type must also support being cast to the desired child type;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+PS
The array's child type must also support being cast to the desired child type(s);
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index b4c0194174a..62192aba283 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -174,7 +174,7 @@ def test_cast_long_to_decimal_overflow(): # casting these types to string is not exact match, marked as xfail when testing not_matched_gens_for_cast_to_string = [float_gen, double_gen, decimal_gen_neg_scale] # casting these types to string is not supported, marked as xfail when testing -not_support_gens_for_cast_to_string = decimal_128_gens +not_support_gens_for_cast_to_string = decimal_128_gens + [MapGen(ByteGen(False), ByteGen())] single_level_array_gens_for_cast_to_string = [ArrayGen(sub_gen) for sub_gen in basic_gens_for_cast_to_string] nested_array_gens_for_cast_to_string = [ @@ -188,13 +188,23 @@ def test_cast_long_to_decimal_overflow(): def _assert_cast_to_string_equal (data_gen, conf): """ - helper function for casting to string + helper function for casting to string of supported type """ assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")), conf ) +def _assert_cast_to_string_fallback (data_gen, conf): + """ + helper function for casting to string of unsupported type + """ + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")), + "Cast", + conf + ) + @pytest.mark.parametrize('data_gen', all_gens_for_cast_to_string, ids=idfn) @pytest.mark.parametrize('legacy', ['true', 'false']) def test_cast_array_to_string(data_gen, legacy): @@ -219,12 +229,13 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy): @pytest.mark.parametrize('data_gen', [ArrayGen(sub) for sub in not_support_gens_for_cast_to_string], ids=idfn) @pytest.mark.parametrize('legacy', ['true', 'false']) -@pytest.mark.xfail(reason='casting this type to string is not supported') -def test_cast_array_with_unsupported_element_to_string(data_gen, legacy): - _assert_cast_to_string_equal( +@allow_non_gpu('ProjectExec', 'Cast', 'Alias') +def test_cast_array_with_unsupported_element_to_string_fallback(data_gen, legacy): + _assert_cast_to_string_fallback( data_gen, {"spark.rapids.sql.castDecimalToString.enabled" : 'true', - "spark.sql.legacy.castComplexTypesToString.enabled": legacy} + "spark.sql.legacy.castComplexTypesToString.enabled": legacy, + "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'} ) @@ -285,11 +296,12 @@ def test_cast_struct_with_unmatched_element_to_string(data_gen, legacy): @pytest.mark.parametrize('data_gen', [StructGen([["first", element_gen]]) for element_gen in not_support_gens_for_cast_to_string], ids=idfn) @pytest.mark.parametrize('legacy', ['true', 'false']) -@pytest.mark.xfail(reason='casting this type to string is not supported') -def test_cast_struct_with_unsupported_element_to_string(data_gen, legacy): - _assert_cast_to_string_equal( +@allow_non_gpu('ProjectExec', 'Cast', 'Alias') +def test_cast_struct_with_unsupported_element_to_string_fallback(data_gen, legacy): + _assert_cast_to_string_fallback( data_gen, - {"spark.rapids.sql.castDecimalToString.enabled" : 'true', - "spark.sql.legacy.castComplexTypesToString.enabled": legacy} + {"spark.rapids.sql.castDecimalToString.enabled" : 'true', + "spark.sql.legacy.castComplexTypesToString.enabled": legacy, + "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'} ) \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 41a8cf0fc01..9b8e251a52e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -124,6 +124,9 @@ final class CastExprMeta[INPUT <: CastBase]( case (fromChild, toChild) => recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1) } + case (ArrayType(elementType, _), StringType) => + recursiveTagExprForGpuCheck(elementType, StringType, depth + 1) + case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 5d03657b714..1706b305fca 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1270,10 +1270,11 @@ class CastChecks extends ExprChecks { val calendarChecks: TypeSig = none val sparkCalendarSig: TypeSig = CALENDAR + STRING - val arrayChecks: TypeSig = STRING + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL + + val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " + + "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL + ARRAY + BINARY + STRUCT + MAP) + - psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + - "the desired child type") + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " + + "desired child type(s)") val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)