From 7a6c423b4d843643b9beb4d4f09d28f58136da07 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Fri, 31 Dec 2021 13:16:22 +0800
Subject: [PATCH 1/2] add recursive type checking for casting array to string
add Map to unsupported element types in IT replace xfail by fallback when
casting unsupported types to string in IT
Signed-off-by: remzi <13716567376yh@gmail.com>
---
docs/supported_ops.md | 8 ++---
.../src/main/python/cast_test.py | 34 +++++++++++++------
.../com/nvidia/spark/rapids/GpuCast.scala | 3 ++
.../com/nvidia/spark/rapids/TypeChecks.scala | 7 ++--
4 files changed, 34 insertions(+), 18 deletions(-)
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 80506c0e161..886087b605d 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16771,12 +16771,12 @@ and the accelerator produces the same result.
|
|
|
-S |
+PS the array's children must also support being cast to string |
|
|
|
|
-PS The array's child type must also support being cast to the desired child type; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS The array's children must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
|
|
|
@@ -17175,12 +17175,12 @@ and the accelerator produces the same result.
|
|
|
-S |
+PS the array's children must also support being cast to string |
|
|
|
|
-PS The array's child type must also support being cast to the desired child type; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS The array's children must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
|
|
|
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py
index b4c0194174a..62192aba283 100644
--- a/integration_tests/src/main/python/cast_test.py
+++ b/integration_tests/src/main/python/cast_test.py
@@ -174,7 +174,7 @@ def test_cast_long_to_decimal_overflow():
# casting these types to string is not exact match, marked as xfail when testing
not_matched_gens_for_cast_to_string = [float_gen, double_gen, decimal_gen_neg_scale]
# casting these types to string is not supported, marked as xfail when testing
-not_support_gens_for_cast_to_string = decimal_128_gens
+not_support_gens_for_cast_to_string = decimal_128_gens + [MapGen(ByteGen(False), ByteGen())]
single_level_array_gens_for_cast_to_string = [ArrayGen(sub_gen) for sub_gen in basic_gens_for_cast_to_string]
nested_array_gens_for_cast_to_string = [
@@ -188,13 +188,23 @@ def test_cast_long_to_decimal_overflow():
def _assert_cast_to_string_equal (data_gen, conf):
"""
- helper function for casting to string
+ helper function for casting to string of supported type
"""
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
conf
)
+def _assert_cast_to_string_fallback (data_gen, conf):
+ """
+ helper function for casting to string of unsupported type
+ """
+ assert_gpu_fallback_collect(
+ lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
+ "Cast",
+ conf
+ )
+
@pytest.mark.parametrize('data_gen', all_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
def test_cast_array_to_string(data_gen, legacy):
@@ -219,12 +229,13 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):
@pytest.mark.parametrize('data_gen', [ArrayGen(sub) for sub in not_support_gens_for_cast_to_string], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
-@pytest.mark.xfail(reason='casting this type to string is not supported')
-def test_cast_array_with_unsupported_element_to_string(data_gen, legacy):
- _assert_cast_to_string_equal(
+@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
+def test_cast_array_with_unsupported_element_to_string_fallback(data_gen, legacy):
+ _assert_cast_to_string_fallback(
data_gen,
{"spark.rapids.sql.castDecimalToString.enabled" : 'true',
- "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
+ "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
+ "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
)
@@ -285,11 +296,12 @@ def test_cast_struct_with_unmatched_element_to_string(data_gen, legacy):
@pytest.mark.parametrize('data_gen', [StructGen([["first", element_gen]]) for element_gen in not_support_gens_for_cast_to_string], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
-@pytest.mark.xfail(reason='casting this type to string is not supported')
-def test_cast_struct_with_unsupported_element_to_string(data_gen, legacy):
- _assert_cast_to_string_equal(
+@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
+def test_cast_struct_with_unsupported_element_to_string_fallback(data_gen, legacy):
+ _assert_cast_to_string_fallback(
data_gen,
- {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
- "spark.sql.legacy.castComplexTypesToString.enabled": legacy}
+ {"spark.rapids.sql.castDecimalToString.enabled" : 'true',
+ "spark.sql.legacy.castComplexTypesToString.enabled": legacy,
+ "spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
)
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 41a8cf0fc01..68c557fbf8d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -124,6 +124,9 @@ final class CastExprMeta[INPUT <: CastBase](
case (fromChild, toChild) =>
recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1)
}
+ case (ArrayType(element_type, _), StringType) =>
+ recursiveTagExprForGpuCheck(element_type, StringType, depth + 1)
+
case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) =>
recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 5687dfcc6fd..5d294c590b0 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1270,10 +1270,11 @@ class CastChecks extends ExprChecks {
val calendarChecks: TypeSig = none
val sparkCalendarSig: TypeSig = CALENDAR + STRING
- val arrayChecks: TypeSig = STRING + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
+ val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's children must also support " +
+ "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
ARRAY + BINARY + STRUCT + MAP) +
- psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " +
- "the desired child type")
+ psNote(TypeEnum.ARRAY, "The array's children must also support being cast to the " +
+ "desired child type(s)")
val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)
From 86e8b30a8668cec37a4ee21b6d992feca9330cb3 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Tue, 4 Jan 2022 09:30:56 +0800
Subject: [PATCH 2/2] fix some minor issues
Signed-off-by: remzi <13716567376yh@gmail.com>
---
docs/supported_ops.md | 8 ++++----
.../src/main/scala/com/nvidia/spark/rapids/GpuCast.scala | 4 ++--
.../main/scala/com/nvidia/spark/rapids/TypeChecks.scala | 4 ++--
3 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 886087b605d..f4685bc3aef 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -16771,12 +16771,12 @@ and the accelerator produces the same result.
|
|
|
-PS the array's children must also support being cast to string |
+PS the array's child type must also support being cast to string |
|
|
|
|
-PS The array's children must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS The array's child type must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
|
|
|
@@ -17175,12 +17175,12 @@ and the accelerator produces the same result.
|
|
|
-PS the array's children must also support being cast to string |
+PS the array's child type must also support being cast to string |
|
|
|
|
-PS The array's children must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS The array's child type must also support being cast to the desired child type(s); UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
|
|
|
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index 68c557fbf8d..9b8e251a52e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -124,8 +124,8 @@ final class CastExprMeta[INPUT <: CastBase](
case (fromChild, toChild) =>
recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1)
}
- case (ArrayType(element_type, _), StringType) =>
- recursiveTagExprForGpuCheck(element_type, StringType, depth + 1)
+ case (ArrayType(elementType, _), StringType) =>
+ recursiveTagExprForGpuCheck(elementType, StringType, depth + 1)
case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) =>
recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 5d294c590b0..61f31d52277 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -1270,10 +1270,10 @@ class CastChecks extends ExprChecks {
val calendarChecks: TypeSig = none
val sparkCalendarSig: TypeSig = CALENDAR + STRING
- val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's children must also support " +
+ val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " +
"being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
ARRAY + BINARY + STRUCT + MAP) +
- psNote(TypeEnum.ARRAY, "The array's children must also support being cast to the " +
+ psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " +
"desired child type(s)")
val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)