From f26c0ff0c21c4faa3085acb27948067c28c06abb Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 29 Aug 2022 10:55:43 -0700 Subject: [PATCH] Added a ShuffleExec fallback test Signed-off-by: Raza Jafri --- docs/supported_ops.md | 28 +++++++++---------- .../src/main/python/hash_aggregate_test.py | 7 +++++ .../nvidia/spark/rapids/GpuOverrides.scala | 4 ++- tools/src/main/resources/supportedExprs.csv | 4 +-- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 811cbfa80a6..48b4783820c 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -7679,14 +7679,14 @@ are limited. S PS
UTC is only supported TZ for TIMESTAMP
S -S -S -S -S -PS
UTC is only supported TZ for child TIMESTAMP
-PS
UTC is only supported TZ for child TIMESTAMP
+ + + + PS
UTC is only supported TZ for child TIMESTAMP
-S + + + result @@ -7700,14 +7700,14 @@ are limited. S PS
UTC is only supported TZ for TIMESTAMP
S -S -S -S -S -PS
UTC is only supported TZ for child TIMESTAMP
-PS
UTC is only supported TZ for child TIMESTAMP
+ + + + PS
UTC is only supported TZ for child TIMESTAMP
-S + + + KnownNotNull diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index 90a9884359a..d9415abe496 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -354,6 +354,13 @@ def test_hash_grpby_sum_count_action(data_gen): lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')) ) +@allow_non_gpu("ShuffleExchangeExec", "HashAggregateExec") +@pytest.mark.parametrize('data_gen', [_grpkey_nested_structs_with_array_child], ids=idfn) +def test_hash_grpby_sum_count_action_fallback(data_gen): + assert_gpu_fallback_collect( + lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')), + 'ShuffleExchangeExec') + @pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn) def test_hash_reduction_sum_count_action(data_gen): assert_gpu_and_cpu_row_counts_equal( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index c376516458a..447bcc03e5b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1680,7 +1680,9 @@ object GpuOverrides extends Logging { }), expr[KnownFloatingPointNormalized]( "Tag to prevent redundant normalization", - ExprChecks.unaryProjectInputMatchesOutput(TypeSig.all, TypeSig.all), + ExprChecks.unaryProjectInputMatchesOutput( + (TypeSig.DOUBLE + TypeSig.FLOAT + TypeSig.ARRAY + TypeSig.commonCudfTypes).nested(), + (TypeSig.DOUBLE + TypeSig.FLOAT + TypeSig.ARRAY + TypeSig.commonCudfTypes).nested()), (a, conf, p, r) => new UnaryExprMeta[KnownFloatingPointNormalized](a, conf, p, r) { override def convertToGpu(child: Expression): GpuExpression = GpuKnownFloatingPointNormalized(child) diff --git a/tools/src/main/resources/supportedExprs.csv b/tools/src/main/resources/supportedExprs.csv index 673b2b86f9c..14e1d79909a 100644 --- a/tools/src/main/resources/supportedExprs.csv +++ b/tools/src/main/resources/supportedExprs.csv @@ -260,8 +260,8 @@ IsNull,S,`isnull`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS IsNull,S,`isnull`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA JsonToStructs,NS,`from_json`,This is disabled by default because parsing JSON from a column has a large number of issues and should be considered beta quality right now.,project,jsonStr,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA JsonToStructs,NS,`from_json`,This is disabled by default because parsing JSON from a column has a large number of issues and should be considered beta quality right now.,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NS,PS,NS,NA -KnownFloatingPointNormalized,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S -KnownFloatingPointNormalized,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S +KnownFloatingPointNormalized,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,NA,NA,NA,NA,PS,NA,NA,NA +KnownFloatingPointNormalized,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,NA,NA,NA,NA,PS,NA,NA,NA KnownNotNull,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS KnownNotNull,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS Lag,S,`lag`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS