diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 80506c0e161..f4685bc3aef 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16771,12 +16771,12 @@ and the accelerator produces the same result.
|
|
|
-S |
+PS the array's child type must also support being cast to string |
|
|
|
|
-PS The array's child type must also support being cast to the desired child type; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS The array's child type must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
|
|
|
@@ -17175,12 +17175,12 @@ and the accelerator produces the same result.
|
|
|
-S |
+PS the array's child type must also support being cast to string |
|
|
|
|
-PS The array's child type must also support being cast to the desired child type; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS The array's child type must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
|
|
|
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py
index b4c0194174a..62192aba283 100644
--- a/integration_tests/src/main/python/cast_test.py
+++ b/integration_tests/src/main/python/cast_test.py
@@ -174,7 +174,7 @@ def test_cast_long_to_decimal_overflow():
# casting these types to string is not exact match, marked as xfail when testing
not_matched_gens_for_cast_to_string = [float_gen, double_gen, decimal_gen_neg_scale]
# casting these types to string is not supported, marked as xfail when testing
-not_support_gens_for_cast_to_string = decimal_128_gens
+not_support_gens_for_cast_to_string = decimal_128_gens + [MapGen(ByteGen(False), ByteGen())]
single_level_array_gens_for_cast_to_string = [ArrayGen(sub_gen) for sub_gen in basic_gens_for_cast_to_string]
nested_array_gens_for_cast_to_string = [
@@ -188,13 +188,23 @@ def test_cast_long_to_decimal_overflow():
def _assert_cast_to_string_equal (data_gen, conf):
"""
- helper function for casting to string
+ helper function asserting that casting a supported type to string gives equal results on CPU and GPU
"""
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
conf
)
+def _assert_cast_to_string_fallback (data_gen, conf):
+ """
+ helper function asserting that casting an unsupported type to string makes "Cast" fall back to the CPU
+ """
+ assert_gpu_fallback_collect(
+ lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
+ "Cast",
+ conf
+ )
+
@pytest.mark.parametrize('data_gen', all_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
def test_cast_array_to_string(data_gen, legacy):
@@ -219,12 +229,13 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):
@pytest.mark.parametrize('data_gen', [ArrayGen(sub) for sub in not_support_gens_for_cast_to_string], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
-@pytest.mark.xfail(reason='casting this type to string is not supported')
-def test_cast_array_with_unsupported_element_to_string(data_gen, legacy):
- _assert_cast_to_string_equal(
+@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
+def test_cast_array_with_unsupported_element_to_string_fallback(data_gen, legacy):
+ _assert_cast_to_string_fallback(
data_gen,
{"spark.rapids.sql.castDecimalToString.enabled" : 'true',
- "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
+ "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
+ "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
)
@@ -285,11 +296,12 @@ def test_cast_struct_with_unmatched_element_to_string(data_gen, legacy):
@pytest.mark.parametrize('data_gen', [StructGen([["first", element_gen]]) for element_gen in not_support_gens_for_cast_to_string], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
-@pytest.mark.xfail(reason='casting this type to string is not supported')
-def test_cast_struct_with_unsupported_element_to_string(data_gen, legacy):
- _assert_cast_to_string_equal(
+@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
+def test_cast_struct_with_unsupported_element_to_string_fallback(data_gen, legacy):
+ _assert_cast_to_string_fallback(
data_gen,
- {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
- "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
+ {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
+ "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
+ "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
)
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 41a8cf0fc01..9b8e251a52e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -124,6 +124,9 @@ final class CastExprMeta[INPUT <: CastBase](
case (fromChild, toChild) =>
recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1)
}
+ case (ArrayType(elementType, _), StringType) =>
+ recursiveTagExprForGpuCheck(elementType, StringType, depth + 1)
+
case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) =>
recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 5d03657b714..1706b305fca 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1270,10 +1270,11 @@ class CastChecks extends ExprChecks {
val calendarChecks: TypeSig = none
val sparkCalendarSig: TypeSig = CALENDAR + STRING
- val arrayChecks: TypeSig = STRING + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
+ val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " +
+ "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
ARRAY + BINARY + STRUCT + MAP) +
- psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " +
- "the desired child type")
+ psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " +
+ "desired child type(s)")
val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)