Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support executing collect_list on GPU with windowing. #1548

Merged
merged 15 commits into from
Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Year"></a>spark.rapids.sql.expression.Year|`year`|Returns the year from a date or timestamp|true|None|
<a name="sql.expression.AggregateExpression"></a>spark.rapids.sql.expression.AggregateExpression| |Aggregate expression|true|None|
<a name="sql.expression.Average"></a>spark.rapids.sql.expression.Average|`avg`, `mean`|Average aggregate operator|true|None|
<a name="sql.expression.CollectList"></a>spark.rapids.sql.expression.CollectList|`collect_list`|Collect a list of elements, now only supported by windowing.|false|This is disabled by default because for now the GPU collects null values to a list, but Spark does not. This will be fixed in future releases.|
<a name="sql.expression.Count"></a>spark.rapids.sql.expression.Count|`count`|Count aggregate operator|true|None|
<a name="sql.expression.First"></a>spark.rapids.sql.expression.First|`first_value`, `first`|first aggregate operator|true|None|
<a name="sql.expression.Last"></a>spark.rapids.sql.expression.Last|`last`, `last_value`|last aggregate operator|true|None|
Expand Down
153 changes: 143 additions & 10 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,9 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
</table>
Expand Down Expand Up @@ -15449,7 +15449,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15491,7 +15491,7 @@ Accelerator support is described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15675,7 +15675,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15717,7 +15717,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -15739,7 +15739,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15781,7 +15781,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -15803,7 +15803,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15845,7 +15845,7 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15984,6 +15984,139 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="6">CollectList</td>
<td rowSpan="6">`collect_list`</td>
<td rowSpan="6">Collect a list of elements, now only supported by windowing.</td>
<td rowSpan="6">This is disabled by default because for now the GPU collects null values to a list, but Spark does not. This will be fixed in future releases.</td>
<td rowSpan="2">aggregation</td>
<td>input</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">reduction</td>
<td>input</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="2">window</td>
<td>input</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S*</td>
<td>S</td>
<td>S*</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT)</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="6">Count</td>
<td rowSpan="6">`count`</td>
<td rowSpan="6">Count aggregate operator</td>
Expand Down
63 changes: 63 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,66 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
' range between 1 preceding and 1 following) as sum_c_asc '
'from window_agg_table'
)


def _gen_data_for_collect(nullable=True):
return [
('a', RepeatSeqGen(LongGen(), length=20)),
('b', IntegerGen()),
('c_int', IntegerGen(nullable=nullable)),
('c_long', LongGen(nullable=nullable)),
('c_time', DateGen(nullable=nullable)),
('c_string', StringGen(nullable=nullable)),
('c_float', FloatGen(nullable=nullable)),
('c_decimal', DecimalGen(nullable=nullable, precision=8, scale=3)),
('c_struct', StructGen(nullable=nullable, children=[
['child_int', IntegerGen()],
['child_time', DateGen()],
['child_string', StringGen()],
['child_decimal', DecimalGen(precision=8, scale=3)]]))]


_collect_sql_string =\
'''
select
collect_list(c_int) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_int,
collect_list(c_long) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_long,
collect_list(c_time) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_time,
collect_list(c_string) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_string,
collect_list(c_float) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_float,
collect_list(c_decimal) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_decimal,
collect_list(c_struct) over
(partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_struct
from window_collect_table
'''

# SortExec does not support array type, so sort the result locally.
@ignore_order(local=True)
@pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1638")
def test_window_aggs_for_rows_collect_list():
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, _gen_data_for_collect(), length=2048),
"window_collect_table",
_collect_sql_string,
{'spark.rapids.sql.expression.CollectList': 'true'})


'''
Spark will drop nulls when collecting, but seems GPU does not yet, so exceptions come up.
Now set nullable to false to verify the current functionality without null values.
Once native supports dropping nulls, will enable the tests above and remove this one.
'''
# SortExec does not support array type, so sort the result locally.
@ignore_order(local=True)
def test_window_aggs_for_rows_collect_list_no_nulls():
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, _gen_data_for_collect(False), length=2048),
"window_collect_table",
_collect_sql_string,
{'spark.rapids.sql.expression.CollectList': 'true'})
Original file line number Diff line number Diff line change
Expand Up @@ -775,11 +775,11 @@ object GpuOverrides {
"\"window\") of rows",
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck("windowFunction",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all),
ParamCheck("windowSpec",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL,
Expand Down Expand Up @@ -1684,11 +1684,13 @@ object GpuOverrides {
expr[AggregateExpression](
"Aggregate expression",
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck(
"aggFunc",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all)),
Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
(a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) {
Expand Down Expand Up @@ -2264,6 +2266,22 @@ object GpuOverrides {
override def convertToGpu(child: Expression): GpuExpression =
GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
}),
expr[CollectList](
"Collect a list of elements, now only supported by windowing.",
// It should be 'fullAgg' eventually but now only support windowing,
// so 'aggNotGroupByOrReduction'
ExprChecks.aggNotGroupByOrReduction(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(ParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
TypeSig.all))),
(c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCollectList(
childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset)
}).disabledByDefault("for now the GPU collects null values to a list, but Spark does not." +
revans2 marked this conversation as resolved.
Show resolved Hide resolved
" This will be fixed in future releases."),
expr[ScalarSubquery](
"Subquery that will return only one row and one column",
ExprChecks.projectOnly(
Expand Down Expand Up @@ -2604,7 +2622,11 @@ object GpuOverrides {
(expand, conf, p, r) => new GpuExpandExecMeta(expand, conf, p, r)),
exec[WindowExec](
"Window-operator backend",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all),
ExecChecks(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL) +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all),
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
),
Expand Down
Loading