Skip to content

Commit

Permalink
Support Collect-like Reduction Aggregations (#4992)
Browse files Browse the repository at this point in the history
Enables collect aggregations under reduction context in spark-rapids

Signed-off-by: sperlingxx [email protected]
  • Loading branch information
sperlingxx authored Apr 11, 2022
1 parent 99e1aab commit f991fd7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 63 deletions.
68 changes: 34 additions & 34 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -14824,25 +14824,25 @@ are limited.
<td rowSpan="6">`collect_list`</td>
<td rowSpan="6">Collect a list of non-unique elements, not supported in reduction</td>
<td rowSpan="6">None</td>
<td rowSpan="2">reduction</td>
<td rowSpan="2">aggregation</td>
<td>input</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -14861,13 +14861,13 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">aggregation</td>
<td rowSpan="2">reduction</td>
<td>input</td>
<td>S</td>
<td>S</td>
Expand Down Expand Up @@ -14983,25 +14983,25 @@ are limited.
<td rowSpan="6">`collect_set`</td>
<td rowSpan="6">Collect a set of unique elements, not supported in reduction</td>
<td rowSpan="6">None</td>
<td rowSpan="2">reduction</td>
<td rowSpan="2">aggregation</td>
<td>input</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15020,13 +15020,13 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">aggregation</td>
<td rowSpan="2">reduction</td>
<td>input</td>
<td>S</td>
<td>S</td>
Expand Down
49 changes: 37 additions & 12 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,18 +1038,18 @@ def test_first_last_reductions_nested_types(data_gen):
def test_generic_reductions(data_gen):
local_conf = copy_and_update(_no_nans_float_conf, {'spark.sql.legacy.allowParameterlessCount': 'true'})
assert_gpu_and_cpu_are_equal_collect(
# Coalesce and sort are to make sure that first and last, which are non-deterministic
# become deterministic
lambda spark : unary_op_df(spark, data_gen)\
.coalesce(1).selectExpr(
'min(a)',
'max(a)',
'first(a)',
'last(a)',
'count(a)',
'count()',
'count(1)'),
conf = local_conf)
# Coalesce and sort are to make sure that first and last, which are non-deterministic
# become deterministic
lambda spark : unary_op_df(spark, data_gen) \
.coalesce(1).selectExpr(
'min(a)',
'max(a)',
'first(a)',
'last(a)',
'count(a)',
'count()',
'count(1)'),
conf=local_conf)

@pytest.mark.parametrize('data_gen', all_gen + _nested_gens, ids=idfn)
def test_count(data_gen):
Expand Down Expand Up @@ -1083,6 +1083,31 @@ def test_arithmetic_reductions(data_gen):
'avg(a)'),
conf = _no_nans_float_conf)

@pytest.mark.parametrize('data_gen',
non_nan_all_basic_gens + decimal_gens + _nested_gens,
ids=idfn)
def test_collect_list_reductions(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr('collect_list(a)'),
conf=_no_nans_float_conf)

_struct_only_nested_gens = [all_basic_struct_gen,
StructGen([['child0', byte_gen], ['child1', all_basic_struct_gen]]),
StructGen([])]
@pytest.mark.parametrize('data_gen',
non_nan_all_basic_gens + decimal_gens + _struct_only_nested_gens,
ids=idfn)
def test_collect_set_reductions(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr('sort_array(collect_set(a))'),
conf=_no_nans_float_conf)

def test_collect_empty():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.sql("select collect_list(null)"))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.sql("select collect_set(null)"))

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen + _nested_gens, ids=idfn)
def test_groupby_first_last(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3218,8 +3218,7 @@ object GpuOverrides extends Logging {
}),
expr[CollectList](
"Collect a list of non-unique elements, not supported in reduction",
// GpuCollectList is not yet supported in Reduction context.
ExprChecks.aggNotReduction(
ExprChecks.fullAgg(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all),
Expand Down Expand Up @@ -3249,10 +3248,7 @@ object GpuOverrides extends Logging {
}),
expr[CollectSet](
"Collect a set of unique elements, not supported in reduction",
// GpuCollectSet is not yet supported in Reduction context.
// Compared to CollectList, ArrayType and MapType are NOT supported in GpuCollectSet
// because underlying cuDF operator drop_list_duplicates doesn't support LIST type.
ExprChecks.aggNotReduction(
ExprChecks.fullAgg(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.all),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -352,32 +352,39 @@ class CudfMin(override val dataType: DataType) extends CudfAggregate {
}

class CudfCollectList(override val dataType: DataType) extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
throw new UnsupportedOperationException("CollectList is not yet supported in reduction")
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(col: cudf.ColumnVector) => col.reduce(ReductionAggregation.collectList(), DType.LIST)
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.collectList()
override val name: String = "CudfCollectList"
}

class CudfMergeLists(override val dataType: DataType) extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
throw new UnsupportedOperationException("MergeLists is not yet supported in reduction")
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(col: cudf.ColumnVector) => col.reduce(ReductionAggregation.mergeLists(), DType.LIST)
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.mergeLists()
override val name: String = "CudfMergeLists"
}

class CudfCollectSet(override val dataType: DataType) extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
throw new UnsupportedOperationException("CollectSet is not yet supported in reduction")
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(col: cudf.ColumnVector) => {
val collectSet = ReductionAggregation.collectSet(
NullPolicy.EXCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL)
col.reduce(collectSet, DType.LIST)
}
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.collectSet(NullPolicy.EXCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL)
override val name: String = "CudfCollectSet"
}

class CudfMergeSets(override val dataType: DataType) extends CudfAggregate {
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
throw new UnsupportedOperationException("CudfMergeSets is not yet supported in reduction")
override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
(col: cudf.ColumnVector) => {
val mergeSets = ReductionAggregation.mergeSets(NullEquality.EQUAL, NaNEquality.UNEQUAL)
col.reduce(mergeSets, DType.LIST)
}
override lazy val groupByAggregate: GroupByAggregation =
GroupByAggregation.mergeSets(NullEquality.EQUAL, NaNEquality.UNEQUAL)
override val name: String = "CudfMergeSets"
Expand Down Expand Up @@ -1578,8 +1585,9 @@ trait GpuCollectBase
// WINDOW FUNCTION
override val windowInputProjection: Seq[Expression] = Seq(child)

// Make them lazy to avoid being initialized when creating a GpuCollectOp.
override lazy val initialValues: Seq[Expression] = throw new UnsupportedOperationException
override val initialValues: Seq[Expression] = {
Seq(GpuLiteral.create(new GenericArrayData(Array.empty[Any]), dataType))
}

override val inputProjection: Seq[Expression] = Seq(child)

Expand Down

0 comments on commit f991fd7

Please sign in to comment.