diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py index 9bbf93a6e26..02c78e296d4 100644 --- a/integration_tests/src/main/python/window_function_test.py +++ b/integration_tests/src/main/python/window_function_test.py @@ -213,3 +213,53 @@ def test_window_aggs_for_ranges_of_dates(data_gen): ' range between 1 preceding and 1 following) as sum_c_asc ' 'from window_agg_table' ) + + +''' + Spark will drop nulls when collecting, but seems GPU does not, so exceptions come up. +E Caused by: java.lang.AssertionError: value at 350 is null +E at ai.rapids.cudf.HostColumnVectorCore.assertsForGet(HostColumnVectorCore.java:228) +E at ai.rapids.cudf.HostColumnVectorCore.getInt(HostColumnVectorCore.java:254) +E at com.nvidia.spark.rapids.RapidsHostColumnVectorCore.getInt(RapidsHostColumnVectorCore.java:109) +E at org.apache.spark.sql.vectorized.ColumnarArray.getInt(ColumnarArray.java:128) + + Now set nullable to false to pass the tests, once native supports dropping nulls, will set it to true. +''' +collect_data_gen = [ + ('a', RepeatSeqGen(LongGen(), length=20)), + ('b', IntegerGen()), + ('c_int', IntegerGen(nullable=False)), + ('c_long', LongGen(nullable=False)), + ('c_time', DateGen(nullable=False)), + ('c_string', StringGen(nullable=False)), + ('c_float', FloatGen(nullable=False)), + ('c_decimal', DecimalGen(nullable=False, precision=8, scale=3)), + ('c_struct', StructGen(nullable=False, children = [ + ['child_int', IntegerGen()], + ['child_time', DateGen()], + ['child_string', StringGen()], + ['child_decimal', DecimalGen(nullable=False, precision=8, scale=3)]]))] + +# SortExec does not support array type, so sort the result locally. +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', [collect_data_gen], ids=idfn) +def test_window_aggs_for_rows_collect_list(data_gen): + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, data_gen, length=2048), + "window_collect_table", + 'select ' + ' collect_list(c_int) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_int, ' + ' collect_list(c_long) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_long, ' + ' collect_list(c_time) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_time, ' + ' collect_list(c_string) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_string, ' + ' collect_list(c_float) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_float, ' + ' collect_list(c_decimal) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_decimal, ' + ' collect_list(c_struct) over ' + ' (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_struct ' + 'from window_collect_table ')