diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 8d9dc3151cb..cdd8f6f1105 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -556,7 +556,7 @@ Accelerator supports are described below.
S |
NS |
NS |
-PS not allowed for grouping expressions; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
+PS not allowed for grouping expressions if containing Struct as child; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
PS not allowed for grouping expressions; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
PS not allowed for grouping expressions if containing Array or Map as child; UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT |
NS |
@@ -724,7 +724,7 @@ Accelerator supports are described below.
S |
S |
NS |
-PS Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS Round-robin partitioning is not supported for nested structs if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
NS |
@@ -7737,45 +7737,45 @@ are limited.
None |
project |
input |
- |
- |
- |
- |
- |
S |
S |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
+S |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for TIMESTAMP |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
result |
- |
- |
- |
- |
- |
S |
S |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
+S |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for TIMESTAMP |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
KnownNotNull |
@@ -18594,9 +18594,9 @@ as `a` don't show up in the table. They are controlled by the rules for
S |
NS |
NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index ce2ffcf3094..752a461f58f 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -128,6 +128,19 @@
('b', FloatGen(nullable=(True, 10.0), special_cases=[(float('nan'), 10.0)])),
('c', LongGen())]
+# grouping single-level lists
+_grpkey_list_with_non_nested_children = [[('a', RepeatSeqGen(ArrayGen(data_gen), length=3)),
+ ('b', IntegerGen())] for data_gen in all_basic_gens + decimal_gens]
+
+# grouping multiple-level structs with arrays
+_grpkey_nested_structs_with_array_basic_child = [
+ ('a', RepeatSeqGen(StructGen([
+ ['aa', IntegerGen()],
+ ['ab', ArrayGen(IntegerGen())]]),
+ length=20)),
+ ('b', IntegerGen()),
+ ('c', NullGen())]
+
_nan_zero_float_special_cases = [
(float('nan'), 5.0),
(NEG_FLOAT_NAN_MIN_VALUE, 5.0),
@@ -335,7 +348,7 @@ def test_hash_reduction_decimal_overflow_sum(precision):
# some optimizations are conspiring against us.
conf = {'spark.rapids.sql.batchSizeBytes': '128m'})
-@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
+@pytest.mark.parametrize('data_gen', [_grpkey_nested_structs_with_array_basic_child, _longs_with_nulls] + _grpkey_list_with_non_nested_children, ids=idfn)
def test_hash_grpby_sum_count_action(data_gen):
assert_gpu_and_cpu_row_counts_equal(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b'))
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index 7b77b7be426..b12a680d3eb 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -214,10 +214,23 @@ def test_round_robin_sort_fallback(data_gen):
lambda spark : gen_df(spark, data_gen).withColumn('extra', lit(1)).repartition(13),
'ShuffleExchangeExec')
+@allow_non_gpu("ProjectExec", "ShuffleExchangeExec")
+@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
+@pytest.mark.parametrize('num_parts', [2, 10, 17, 19, 32], ids=idfn)
+@pytest.mark.parametrize('gen', [([('ag', ArrayGen(StructGen([('b1', long_gen)])))], ['ag'])], ids=idfn)
+def test_hash_repartition_exact_fallback(gen, num_parts):
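+    # Hashing arrays with structs is not supported on the GPU, so repartitioning by an
+    # array-of-struct key must fall back to the CPU ShuffleExchangeExec.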
+    data_gen = gen[0]
+    part_on = gen[1]
+    assert_gpu_fallback_collect(
+        lambda spark : gen_df(spark, data_gen, length=1024) \
+            .repartition(num_parts, *part_on) \
+            .withColumn('id', f.spark_partition_id()) \
+            .selectExpr('*'), "ShuffleExchangeExec")
+
@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
@pytest.mark.parametrize('num_parts', [1, 2, 10, 17, 19, 32], ids=idfn)
@pytest.mark.parametrize('gen', [
- ([('a', boolean_gen)], ['a']),
+    ([('a', boolean_gen)], ['a']),
([('a', byte_gen)], ['a']),
([('a', short_gen)], ['a']),
([('a', int_gen)], ['a']),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 30a18b6d77f..cf33be44905 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1645,9 +1645,7 @@ object GpuOverrides extends Logging {
}),
expr[KnownFloatingPointNormalized](
"Tag to prevent redundant normalization",
-      ExprChecks.unaryProjectInputMatchesOutput(
-        TypeSig.DOUBLE + TypeSig.FLOAT,
-        TypeSig.DOUBLE + TypeSig.FLOAT),
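+      // KnownFloatingPointNormalized is just a tag and does no computation itself, so let
+      // any input type pass through to the output unchanged.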
+      ExprChecks.unaryProjectInputMatchesOutput(TypeSig.all, TypeSig.all),
(a, conf, p, r) => new UnaryExprMeta[KnownFloatingPointNormalized](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuKnownFloatingPointNormalized(child)
@@ -3692,11 +3690,26 @@ object GpuOverrides extends Logging {
// This needs to match what murmur3 supports.
PartChecks(RepeatingParamCheck("hash_key",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
-            TypeSig.STRUCT).nested(), TypeSig.all)),
+            TypeSig.STRUCT + TypeSig.ARRAY).nested(),
+          TypeSig.all)
+      ),
(hp, conf, p, r) => new PartMeta[HashPartitioning](hp, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] =
hp.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+      override def tagPartForGpu(): Unit = {
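+        // Arrays of basic types can be hashed on the GPU, but arrays containing structs
+        // cannot, so tag the exchange to fall back to the CPU in that case.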
+        val arrayWithStructsHashing = hp.expressions.exists(e =>
+          TrampolineUtil.dataTypeExistsRecursively(e.dataType,
+            dt => dt match {
+              case ArrayType(_: StructType, _) => true
+              case _ => false
+            })
+        )
+        if (arrayWithStructsHashing) {
+          willNotWorkOnGpu("hashing arrays with structs is not supported")
+        }
+      }
+
override def convertToGpu(): GpuPartitioning =
GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
}),
@@ -3912,7 +3925,7 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " +
s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(
-          Seq(TypeEnum.ARRAY, TypeEnum.MAP),
+          Seq(TypeEnum.MAP),
"Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"),
TypeSig.all),
@@ -3974,10 +3987,12 @@ object GpuOverrides extends Logging {
"The backend for hash based aggregations",
ExecChecks(
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
-          TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT)
+          TypeSig.MAP + TypeSig.STRUCT + TypeSig.ARRAY)
.nested()
-        .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
+        .withPsNote(TypeEnum.MAP,
"not allowed for grouping expressions")
+        .withPsNote(TypeEnum.ARRAY,
+          "not allowed for grouping expressions if containing Struct as child")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
TypeSig.all),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index 3a9c9595ab2..6eb222335cd 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.rapids.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter}
import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil}
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
object AggregateUtils {
@@ -852,13 +852,27 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes ++ resultExpressions
override def tagPlanForGpu(): Unit = {
- // We don't support Arrays and Maps as GroupBy keys yet, even they are nested in Structs. So,
+    // We don't support Maps as GroupBy keys yet, even if they are nested in Structs. So,
// we need to run recursive type check on the structs.
-    val arrayOrMapGroupings = agg.groupingExpressions.exists(e =>
+    val mapGroupings = agg.groupingExpressions.exists(e =>
TrampolineUtil.dataTypeExistsRecursively(e.dataType,
-        dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType]))
-    if (arrayOrMapGroupings) {
-      willNotWorkOnGpu("ArrayTypes or MapTypes in grouping expressions are not supported")
+        dt => dt.isInstanceOf[MapType]))
+    if (mapGroupings) {
+      willNotWorkOnGpu("MapTypes in grouping expressions are not supported")
+    }
+
+    // We support Arrays as grouping expressions, but not if the child is a Struct. So we
+    // need to run a recursive type check for arrays of structs.
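+    // For example, grouping by an ARRAY<INT> column can stay on the GPU, while grouping by
+    // an array of structs falls back to the CPU.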
+    val arrayWithStructsGroupings = agg.groupingExpressions.exists(e =>
+      TrampolineUtil.dataTypeExistsRecursively(e.dataType,
+        dt => dt match {
+          case ArrayType(_: StructType, _) => true
+          case _ => false
+        })
+    )
+    if (arrayWithStructsGroupings) {
+      willNotWorkOnGpu("ArrayTypes with Struct children in grouping expressions are not " +
+        "supported")
}
tagForReplaceMode()
diff --git a/tools/src/main/resources/supportedExprs.csv b/tools/src/main/resources/supportedExprs.csv
index cd37c792ccc..7b93812944c 100644
--- a/tools/src/main/resources/supportedExprs.csv
+++ b/tools/src/main/resources/supportedExprs.csv
@@ -263,8 +263,8 @@ IsNull,S,`isnull`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS
IsNull,S,`isnull`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
JsonToStructs,NS,`from_json`,This is disabled by default because parsing JSON from a column has a large number of issues and should be considered beta quality right now.,project,jsonStr,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
JsonToStructs,NS,`from_json`,This is disabled by default because parsing JSON from a column has a large number of issues and should be considered beta quality right now.,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NS,PS,NS,NA
-KnownFloatingPointNormalized,S, ,None,project,input,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
-KnownFloatingPointNormalized,S, ,None,project,result,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+KnownFloatingPointNormalized,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
+KnownFloatingPointNormalized,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
KnownNotNull,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
KnownNotNull,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
Lag,S,`lag`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS