Skip to content

Commit

Permalink
Add integration tests for collect_list.
Browse files Browse the repository at this point in the history
Add integration tests for collect_list with windowing.

Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman committed Jan 28, 2021
1 parent c5c982d commit 071910e
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,53 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
' range between 1 preceding and 1 following) as sum_c_asc '
'from window_agg_table'
)


'''
Spark will drop nulls when collecting, but seems GPU does not, so exceptions come up.
E Caused by: java.lang.AssertionError: value at 350 is null
E at ai.rapids.cudf.HostColumnVectorCore.assertsForGet(HostColumnVectorCore.java:228)
E at ai.rapids.cudf.HostColumnVectorCore.getInt(HostColumnVectorCore.java:254)
E at com.nvidia.spark.rapids.RapidsHostColumnVectorCore.getInt(RapidsHostColumnVectorCore.java:109)
E at org.apache.spark.sql.vectorized.ColumnarArray.getInt(ColumnarArray.java:128)
Now set nullable to false to pass the tests, once native supports dropping nulls, will set it to true.
'''
collect_data_gen = [
('a', RepeatSeqGen(LongGen(), length=20)),
('b', IntegerGen()),
('c_int', IntegerGen(nullable=False)),
('c_long', LongGen(nullable=False)),
('c_time', DateGen(nullable=False)),
('c_string', StringGen(nullable=False)),
('c_float', FloatGen(nullable=False)),
('c_decimal', DecimalGen(nullable=False, precision=8, scale=3)),
('c_struct', StructGen(nullable=False, children = [
['child_int', IntegerGen()],
['child_time', DateGen()],
['child_string', StringGen()],
['child_decimal', DecimalGen(nullable=False, precision=8, scale=3)]]))]

# SortExec does not support array type, so sort the result locally.
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [collect_data_gen], ids=idfn)
def test_window_aggs_for_rows_collect_list(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen, length=2048),
"window_collect_table",
'select '
' collect_list(c_int) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_int, '
' collect_list(c_long) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_long, '
' collect_list(c_time) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_time, '
' collect_list(c_string) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_string, '
' collect_list(c_float) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_float, '
' collect_list(c_decimal) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_decimal, '
' collect_list(c_struct) over '
' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_struct '
'from window_collect_table ')

0 comments on commit 071910e

Please sign in to comment.