From a63b4832d8e094a46353635238591133d394091e Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Sat, 15 Jan 2022 08:37:51 +0800 Subject: [PATCH] Support max on single-level struct in aggregation context (#4434) * Support max on single-level struct in aggregation context * Refactor * Support min and max on single-level * Update test case after Cudf fixed bug about null Signed-off-by: Chong Gao --- docs/supported_ops.md | 16 +++--- .../src/main/python/hash_aggregate_test.py | 43 ++++++++++++++++ .../nvidia/spark/rapids/GpuOverrides.scala | 50 ++++++++++++++----- 3 files changed, 89 insertions(+), 20 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index a64f432983e..0feefaec297 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -15230,7 +15230,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15251,7 +15251,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15273,7 +15273,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15294,7 +15294,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15389,7 +15389,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15410,7 +15410,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15432,7 +15432,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS @@ -15453,7 +15453,7 @@ are limited. NS NS -NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT
NS diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index f862a11d996..54e2f913784 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -1704,3 +1704,46 @@ def test_groupby_std_variance_partial_replace_fallback(data_gen, exist_classes=','.join(exist_clz), non_exist_classes=','.join(non_exist_clz), conf=local_conf) + +# +# test min max on single level structure +# +gens_for_max_min = [byte_gen, short_gen, int_gen, long_gen, + FloatGen(no_nans = True), DoubleGen(no_nans = True), + string_gen, boolean_gen, + date_gen, timestamp_gen, + DecimalGen(precision=12, scale=2), + DecimalGen(precision=36, scale=5), + null_gen] +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', gens_for_max_min, ids=idfn) +def test_min_max_for_single_level_struct(data_gen): + df_gen = [ + ('a', StructGen([ + ('aa', data_gen), + ('ab', data_gen)])), + ('b', RepeatSeqGen(IntegerGen(), length=20))] + + # test max + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, df_gen), + "hash_agg_table", + 'select b, max(a) from hash_agg_table group by b', + _no_nans_float_conf) + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, df_gen), + "hash_agg_table", + 'select max(a) from hash_agg_table', + _no_nans_float_conf) + + # test min + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, df_gen, length=1024), + "hash_agg_table", + 'select b, min(a) from hash_agg_table group by b', + _no_nans_float_conf) + assert_gpu_and_cpu_are_equal_sql( + lambda spark : gen_df(spark, df_gen, length=1024), + "hash_agg_table", + 'select min(a) from hash_agg_table', + _no_nans_float_conf) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 7637686d898..31401832bbb 100644 --- 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2234,14 +2234,27 @@ object GpuOverrides extends Logging { }), expr[Max]( "Max aggregate operator", - ExprChecks.fullAgg( - TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable, - Seq(ParamCheck("input", - (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + ExprChecksImpl( + ExprChecks.reductionAndGroupByAgg( + // Max supports single level struct, e.g.: max(struct(string, string)) + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) .withPsNote(TypeEnum.FLOAT, nanAggPsNote), - TypeSig.orderable)) - ), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts + ++ + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) { override def tagAggForGpu(): Unit = { val dataType = max.child.dataType @@ -2256,14 +2269,27 @@ object GpuOverrides extends Logging { }), expr[Min]( "Min aggregate operator", - ExprChecks.fullAgg( - TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable, - Seq(ParamCheck("input", - (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + ExprChecksImpl( + ExprChecks.reductionAndGroupByAgg( + // Min supports single level struct, e.g.: 
min(struct(string, string)) + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) .withPsNote(TypeEnum.FLOAT, nanAggPsNote), - TypeSig.orderable)) - ), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts + ++ + ExprChecks.windowOnly( + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL), + TypeSig.orderable, + Seq(ParamCheck("input", + (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL) + .withPsNote(TypeEnum.DOUBLE, nanAggPsNote) + .withPsNote(TypeEnum.FLOAT, nanAggPsNote), + TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts), (a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) { override def tagAggForGpu(): Unit = { val dataType = a.child.dataType