Add support for arrays in hashaggregate [databricks] (#7465)
* Revert "Revert "Add support for arrays in hashaggregate [databricks] (#6066)" (#6679)"

This reverts commit c05ac2d and adds tests.

* Add test for aggregation on array

* Update docs

---------

Signed-off-by: Raza Jafri <[email protected]>
razajafri authored Jul 13, 2023
1 parent 04d9080 commit e7817e4
Showing 6 changed files with 126 additions and 58 deletions.
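
As a quick sketch of what this commit enables (assuming a SparkSession `spark` with the RAPIDS Accelerator plugin enabled; the data and column names below are hypothetical, not from this change), grouping by an array key whose children are not structs can now stay on the GPU for min/max aggregations:

from pyspark.sql import functions as f

# Hypothetical data: 'a' is an array<int> grouping key, 'b' is a long value.
df = spark.createDataFrame(
    [([1, 2], 10), ([1, 2], 20), ([3], 5)],
    'a array<int>, b long')

# With this commit the group-by on the array column is eligible for
# GpuHashAggregateExec instead of falling back to the CPU.
df.groupby('a').agg(f.min('b'), f.max('b')).show()
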
76 changes: 38 additions & 38 deletions docs/supported_ops.md
@@ -556,7 +556,7 @@ Accelerator supports are described below.
<td>S</td>
<td><em>PS<br/>not allowed for grouping expressions</em></td>
<td><b>NS</b></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions if containing Struct as child;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>not allowed for grouping expressions if containing Array, Map, or Binary as child;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
@@ -748,7 +748,7 @@ Accelerator supports are described below.
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS<br/>Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>Round-robin partitioning is not supported for nested structs if spark.sql.execution.sortBeforeRepartition is true;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
@@ -8042,45 +8042,45 @@ are limited.
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td>S</td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP</em></td>
<td>S</td>
</tr>
<tr>
<td rowSpan="2">KnownNotNull</td>
@@ -10046,9 +10046,9 @@ are limited.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, MAP, UDT</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, MAP, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
@@ -19270,9 +19270,9 @@ as `a` don't show up in the table. They are controlled by the rules for
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, MAP, UDT</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, MAP, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
29 changes: 27 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -116,6 +116,22 @@
('b', FloatGen(nullable=(True, 10.0), special_cases=[(float('nan'), 10.0)])),
('c', LongGen())]

# grouping single-level lists
# StringGen for the value being aggregated will force cuDF to do a sort-based aggregation internally.
_grpkey_list_with_non_nested_children = [[('a', RepeatSeqGen(ArrayGen(data_gen), length=3)),
('b', IntegerGen())] for data_gen in all_basic_gens + decimal_gens] + \
[[('a', RepeatSeqGen(ArrayGen(data_gen), length=3)),
('b', StringGen())] for data_gen in all_basic_gens + decimal_gens]

# grouping multi-level structs with arrays
_grpkey_nested_structs_with_array_basic_child = [[
('a', RepeatSeqGen(StructGen([
['aa', IntegerGen()],
['ab', ArrayGen(IntegerGen())]]),
length=20)),
('b', IntegerGen()),
('c', NullGen())]]

_nan_zero_float_special_cases = [
(float('nan'), 5.0),
(NEG_FLOAT_NAN_MIN_VALUE, 5.0),
@@ -324,6 +340,14 @@ def test_hash_grpby_sum_count_action(data_gen):
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b'))
)

@allow_non_gpu("SortAggregateExec", "SortExec", "ShuffleExchangeExec")
@ignore_order
@pytest.mark.parametrize('data_gen', _grpkey_nested_structs_with_array_basic_child + _grpkey_list_with_non_nested_children, ids=idfn)
def test_hash_grpby_list_min_max(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100).coalesce(1).groupby('a').agg(f.min('b'), f.max('b'))
)

@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
def test_hash_reduction_sum_count_action(data_gen):
assert_gpu_and_cpu_row_counts_equal(
@@ -1199,7 +1223,9 @@ def test_agg_count(data_gen, count_func):
@ignore_order(local=True)
@allow_non_gpu('HashAggregateExec', 'Alias', 'AggregateExpression', 'Cast',
'HashPartitioning', 'ShuffleExchangeExec', 'Count')
@pytest.mark.parametrize('data_gen', array_gens_sample + [binary_gen], ids=idfn)
@pytest.mark.parametrize('data_gen',
[ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))
, binary_gen], ids=idfn)
@pytest.mark.parametrize('count_func', [f.count, f.countDistinct])
def test_groupby_list_types_fallback(data_gen, count_func):
assert_gpu_fallback_collect(
@@ -1718,7 +1744,6 @@ def do_it(spark):
assert_gpu_and_cpu_are_equal_collect(do_it,
conf={'spark.sql.ansi.enabled': 'true'})


# Tests for standard deviation and variance aggregations.
@ignore_order(local=True)
@approximate_float
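
For the fallback path that test_groupby_list_types_fallback above exercises, a minimal sketch (same assumptions as before; the schema is hypothetical): grouping by an array whose elements are structs is still tagged as unsupported, so the aggregation stays on the CPU:

from pyspark.sql import functions as f

# Hypothetical data: 'a' is array<struct<aa:int>>, still unsupported as a
# grouping key, so the plan keeps HashAggregateExec on the CPU.
df = spark.createDataFrame(
    [([(1,)], 10), ([(2,)], 20)],
    'a array<struct<aa:int>>, b long')

df.groupby('a').agg(f.count('b')).collect()
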
19 changes: 17 additions & 2 deletions integration_tests/src/main/python/repart_test.py
@@ -214,10 +214,23 @@ def test_round_robin_sort_fallback(data_gen):
lambda spark : gen_df(spark, data_gen).withColumn('extra', lit(1)).repartition(13),
'ShuffleExchangeExec')

@allow_non_gpu("ProjectExec", "ShuffleExchangeExec")
@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
@pytest.mark.parametrize('num_parts', [2, 10, 17, 19, 32], ids=idfn)
@pytest.mark.parametrize('gen', [([('ag', ArrayGen(StructGen([('b1', long_gen)])))], ['ag'])], ids=idfn)
def test_hash_repartition_exact_fallback(gen, num_parts):
data_gen = gen[0]
part_on = gen[1]
assert_gpu_fallback_collect(
lambda spark : gen_df(spark, data_gen, length=1024) \
.repartition(num_parts, *part_on) \
.withColumn('id', f.spark_partition_id()) \
.selectExpr('*'), "ShuffleExchangeExec")

@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
@pytest.mark.parametrize('num_parts', [1, 2, 10, 17, 19, 32], ids=idfn)
@pytest.mark.parametrize('gen', [
([('a', boolean_gen)], ['a']),
([('a', byte_gen)], ['a']),
([('a', short_gen)], ['a']),
([('a', int_gen)], ['a']),
@@ -235,7 +248,9 @@ def test_round_robin_sort_fallback(data_gen):
([('a', long_gen), ('b', StructGen([('b1', long_gen)]))], ['a']),
([('a', long_gen), ('b', ArrayGen(long_gen, max_length=2))], ['a']),
([('a', byte_gen)], [f.col('a') - 5]),
([('a', long_gen)], [f.col('a') + 15]),
([('a', ArrayGen(long_gen, max_length=2)), ('b', long_gen)], ['a']),
([('a', StructGen([('aa', ArrayGen(long_gen, max_length=2))])), ('b', long_gen)], ['a']),
([('a', byte_gen), ('b', boolean_gen)], ['a', 'b']),
([('a', short_gen), ('b', string_gen)], ['a', 'b']),
([('a', int_gen), ('b', byte_gen)], ['a', 'b']),
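
The repart_test.py changes above have a matching user-visible effect; a hedged sketch (hypothetical schema, plugin assumed enabled): hash repartitioning on a plain array key is now eligible for the GPU, while an array-of-structs key falls back to ShuffleExchangeExec on the CPU:

# Hypothetical data: 'a' is an array<long> partitioning key.
df = spark.createDataFrame(
    [([1, 2], 10), ([3], 20)],
    'a array<long>, b long')

# Eligible for the GPU shuffle after this commit; a column typed
# array<struct<...>> would instead be tagged to stay on the CPU.
df.repartition(4, 'a').collect()
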
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1565,9 +1565,7 @@ object GpuOverrides extends Logging {
}),
expr[KnownFloatingPointNormalized](
"Tag to prevent redundant normalization",
ExprChecks.unaryProjectInputMatchesOutput(
TypeSig.DOUBLE + TypeSig.FLOAT,
TypeSig.DOUBLE + TypeSig.FLOAT),
ExprChecks.unaryProjectInputMatchesOutput(TypeSig.all, TypeSig.all),
(a, conf, p, r) => new UnaryExprMeta[KnownFloatingPointNormalized](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuKnownFloatingPointNormalized(child)
@@ -3070,7 +3068,7 @@ object GpuOverrides extends Logging {
ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT,
repeatingParamCheck = Some(RepeatingParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.STRUCT).nested(), TypeSig.all))),
TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.all))),
(a, conf, p, r) => new ExprMeta[Murmur3Hash](a, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] = a.children
.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
@@ -3592,11 +3590,26 @@ object GpuOverrides extends Logging {
// This needs to match what murmur3 supports.
PartChecks(RepeatingParamCheck("hash_key",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.STRUCT).nested(), TypeSig.all)),
TypeSig.STRUCT + TypeSig.ARRAY).nested(),
TypeSig.all)
),
(hp, conf, p, r) => new PartMeta[HashPartitioning](hp, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] =
hp.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))

override def tagPartForGpu(): Unit = {
val arrayWithStructsHashing = hp.expressions.exists(e =>
TrampolineUtil.dataTypeExistsRecursively(e.dataType,
dt => dt match {
case ArrayType(_: StructType, _) => true
case _ => false
})
)
if (arrayWithStructsHashing) {
willNotWorkOnGpu("hashing arrays with structs is not supported")
}
}

override def convertToGpu(): GpuPartitioning =
GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
}),
@@ -3844,7 +3857,7 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " +
s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(
Seq(TypeEnum.ARRAY, TypeEnum.MAP),
Seq(TypeEnum.MAP),
"Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"),
TypeSig.all),
@@ -3909,8 +3922,10 @@ object GpuOverrides extends Logging {
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY +
TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT)
.nested()
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP, TypeEnum.BINARY),
.withPsNote(Seq(TypeEnum.MAP, TypeEnum.BINARY),
"not allowed for grouping expressions")
.withPsNote(TypeEnum.ARRAY,
"not allowed for grouping expressions if containing Struct as child")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array, Map, or Binary as child"),
TypeSig.all),
25 changes: 19 additions & 6 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -1036,17 +1036,30 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes ++ resultExpressions

override def tagPlanForGpu(): Unit = {
// We don't support Arrays and Maps as GroupBy keys yet, even they are nested in Structs. So,
// We don't support Maps as GroupBy keys yet, even if they are nested in Structs. So,
// we need to run recursive type check on the structs.
val listTypeGroupings = agg.groupingExpressions.exists(e =>
val mapOrBinaryGroupings = agg.groupingExpressions.exists(e =>
TrampolineUtil.dataTypeExistsRecursively(e.dataType,
dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType]
|| dt.isInstanceOf[BinaryType]))
if (listTypeGroupings) {
willNotWorkOnGpu("ArrayType, MapType, or BinaryType " +
dt => dt.isInstanceOf[MapType] || dt.isInstanceOf[BinaryType]))
if (mapOrBinaryGroupings) {
willNotWorkOnGpu("MapType, or BinaryType " +
"in grouping expressions are not supported")
}

// We support Arrays as grouping expressions but not if the child is a struct. So we need to
// run a recursive type check on the lists of structs
val arrayWithStructsGroupings = agg.groupingExpressions.exists(e =>
TrampolineUtil.dataTypeExistsRecursively(e.dataType,
dt => dt match {
case ArrayType(_: StructType, _) => true
case _ => false
})
)
if (arrayWithStructsGroupings) {
willNotWorkOnGpu("ArrayTypes with Struct children in grouping expressions are not " +
"supported")
}

tagForReplaceMode()

if (agg.aggregateExpressions.exists(expr => expr.isDistinct)
6 changes: 3 additions & 3 deletions tools/generated_files/supportedExprs.csv
@@ -271,8 +271,8 @@ JsonToStructs,NS,`from_json`,This is disabled by default because parsing JSON fr
JsonTuple,S,`json_tuple`,None,project,json,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
JsonTuple,S,`json_tuple`,None,project,field,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
JsonTuple,S,`json_tuple`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA
KnownFloatingPointNormalized,S, ,None,project,input,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
KnownFloatingPointNormalized,S, ,None,project,result,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
KnownFloatingPointNormalized,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
KnownFloatingPointNormalized,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
KnownNotNull,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
KnownNotNull,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
Lag,S,`lag`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS
@@ -352,7 +352,7 @@ Multiply,S,`*`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,N
Multiply,S,`*`,None,AST,lhs,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NA,NA,NA,NA,NA
Multiply,S,`*`,None,AST,rhs,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NA,NA,NA,NA,NA
Multiply,S,`*`,None,AST,result,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NA,NA,NA,NA,NA
Murmur3Hash,S,`hash`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS
Murmur3Hash,S,`hash`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS
Murmur3Hash,S,`hash`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
NaNvl,S,`nanvl`,None,project,lhs,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
NaNvl,S,`nanvl`,None,project,rhs,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
