From 33db86e7f4d441111ee1e594f0cd6492fac45cc1 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 2 Dec 2021 12:30:27 -0600 Subject: [PATCH 1/5] Full support for SUM overflow detection on decimal Signed-off-by: Robert (Bobby) Evans --- docs/compatibility.md | 83 ++-- docs/supported_ops.md | 8 +- .../src/main/python/hash_aggregate_test.py | 82 +++- .../src/main/python/window_function_test.py | 8 +- .../spark/sql/rapids/aggregate/GpuSum.scala | 11 +- .../spark/sql/rapids/aggregate/GpuSum.scala | 13 +- .../com/nvidia/spark/rapids/GpuCast.scala | 30 +- .../nvidia/spark/rapids/GpuOverrides.scala | 18 +- .../nvidia/spark/rapids/GpuWindowExec.scala | 8 +- .../spark/rapids/GpuWindowExpression.scala | 147 +++++- .../com/nvidia/spark/rapids/aggregate.scala | 93 +--- .../spark/rapids/basicPhysicalOperators.scala | 46 +- .../spark/sql/rapids/AggregateFunctions.scala | 420 ++++++++++++++++-- .../apache/spark/sql/rapids/arithmetic.scala | 43 +- 14 files changed, 710 insertions(+), 300 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index 125256fbe6e..a3e9862588e 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -109,63 +109,31 @@ a few operations that we cannot support to the same degree as Spark can on the C ### Decimal Sum Aggregation -When Apache Spark does a sum aggregation on decimal values it will store the result in a value -with a precision that is the input precision + 10, but with a maximum precision of 38. The table -below shows the number of rows/values in an aggregation before an overflow is possible, -and the number of rows/values in the aggregation before an overflow might not be detected. -The numbers are for Spark 3.1.0 and above after a number of fixes were put in place, please see +A number of fixes for overflow detection went into Spark 3.1.0. Please see [SPARK-28067](https://issues.apache.org/jira/browse/SPARK-28067) and -[SPARK-32018](https://issues.apache.org/jira/browse/SPARK-32018) for more information. -Please also note that these are for the worst case situations, meaning all the values in the sum -were either the largest or smallest values possible to be stored in the input type. In the common -case, where the numbers are smaller, or vary between positive and negative values, many more -rows/values can be processed without any issues. 
- -|Input Precision|Number of values before overflow is possible|Maximum number of values for guaranteed overflow detection (Spark CPU)|Maximum number of values for guaranteed overflow detection (RAPIDS GPU)| -|---------------|------------------------------|------------|-------------| -|1 |11,111,111,111 |2,049,638,219,301,061,290 |Same as CPU | -|2 |10,101,010,101 |186,330,738,118,278,299 |Same as CPU | -|3 |10,010,010,010 |18,465,199,272,982,534 |Same as CPU | -|4 |10,001,000,100 |1,844,848,892,260,181 |Same as CPU | -|5 |10,000,100,001 |184,459,285,329,948 |Same as CPU | -|6 |10,000,010,000 |18,436,762,510,472 |Same as CPU | -|7 |10,000,001,000 |1,834,674,590,838 |Same as CPU | -|8 |10,000,000,100 |174,467,442,481 |Same as CPU | -|9 |10,000,000,010 |Unlimited |Unlimited | -|10 - 19 |10,000,000,000 |Unlimited |Unlimited | -|20 |10,000,000,000 |Unlimited |3,402,823,659,209,384,634 | -|21 |10,000,000,000 |Unlimited |340,282,356,920,938,463 | -|22 |10,000,000,000 |Unlimited |34,028,226,692,093,846 | -|23 |10,000,000,000 |Unlimited |3,402,813,669,209,384 | -|24 |10,000,000,000 |Unlimited |340,272,366,920,938 | -|25 |10,000,000,000 |Unlimited |34,018,236,692,093 | -|26 |10,000,000,000 |Unlimited |3,392,823,669,209 | -|27 |10,000,000,000 |Unlimited |330,282,366,920 | -|28 |10,000,000,000 |Unlimited |24,028,236,692 | -|29 |1,000,000,000 |Unlimited |Falls back to CPU | -|30 |100,000,000 |Unlimited |Falls back to CPU | -|31 |10,000,000 |Unlimited |Falls back to CPU | -|32 |1,000,000 |Unlimited |Falls back to CPU | -|33 |100,000 |Unlimited |Falls back to CPU | -|34 |10,00 |Unlimited |Falls back to CPU | -|35 |1,000 |Unlimited |Falls back to CPU | -|36 |100 |Unlimited |Falls back to CPU | -|37 |10 |Unlimited |Falls back to CPU | -|38 |1 |Unlimited |Falls back to CPU | - -For an input precision of 9 and above, Spark will do the aggregations as a `BigDecimal` -value which is slow, but guarantees that any overflow can be detected. For inputs with a -precision of 8 or below Spark will internally do the calculations as a long value, 64-bits. -When the precision is 8, you would need at least 174-billion values/rows contributing to a -single aggregation result, and even then all the values would need to be either the largest -or the smallest value possible to be stored in the type before the overflow is no longer detected. - -For the RAPIDS Accelerator we only have access to at most a 128-bit value to store the results -in and still detect overflow. Because of this we cannot guarantee overflow detection in all -cases. In some cases we can guarantee unlimited overflow detection because of the maximum number of -values that RAPIDS will aggregate in a single batch. But even in the worst cast for a decimal value -with a precision of 28 the user would still have to aggregate so many values that it overflows 2.4 -times over before we are no longer able to detect it. +[SPARK-32018](https://issues.apache.org/jira/browse/SPARK-32018) for more detailed information. +Some of these fixes we were able to back port, but some of them require Spark 3.1.0 or above to +fully be able to detect overflow in all cases. As such on versions of Spark older than 3.1.0 for +large decimal values there is the possibility of data corruption in some corner cases. +This is true for both the CPU and GPU implementations, but there are fewer of these cases for the +GPU. If this concerns you, you should upgrade to Spark 3.1.0 or above. 
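+
+As a concrete illustration of the kind of overflow this section is about, the following sketch
+(a minimal example, runnable with only PySpark installed; it mirrors the construction used in
+the plugin's integration tests, and it scans ten billion rows, so it is slow) sums just enough
+maximum-valued `DECIMAL(10, 0)` rows to overflow the `DECIMAL(20, 0)` result. On Spark 3.1.0 and
+above this produces a `NULL` (or an exception with `spark.sql.ansi.enabled=true`) instead of a
+corrupted value:
+
+```python
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.master("local[1]").getOrCreate()
+
+# 10,000,000,002 rows of the largest DECIMAL(10, 0) value: the SUM is planned as
+# DECIMAL(20, 0) (input precision + 10) and the total just barely overflows it.
+df = spark.range(0, 10000000002, 1, 48) \
+    .selectExpr("CAST('9999999999' AS DECIMAL(10, 0)) AS a")
+df.selectExpr("SUM(a)").show()
+```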
+ +When Apache Spark does a sum aggregation on decimal values it will store the result in a value +with a precision that is the input precision + 10, but with a maximum precision of 38. +For an input precision of 9 and above, Spark will do the aggregations as a java `BigDecimal` +value which is slow, but guarantees that any overflow can be detected because it can work with +effectively unlimited precision. For inputs with a precision of 8 or below Spark will internally do +the calculations as a long value, 64-bits. When the precision is 8, you would need at least +174,467,442,482 values/rows contributing to a single aggregation result before the overflow is no +longer detected. Even then all the values would need to be either the largest or the smallest value +possible to be stored in the type for the overflow to cause data corruption. + +For the RAPIDS Accelerator we don't have direct access to unlimited precision for our calculations +like the CPU does. For input values with a precision of 8 and below we follow Spark and process the +data the same way, as a 64-bit value. For larger values we will do extra calculations looking at the +higher order digits to be able to detect overflow in all cases. But because of this you may see +some performance differences depending on the input precision used. The differences will show up +when going from an input precision of 8 to 9 and again when going from an input precision of 28 to 29. ### Decimal Average @@ -175,8 +143,7 @@ have. It also inherits some issues from Spark itself. See https://issues.apache.org/jira/browse/SPARK-37024 for a detailed description of some issues with average in Spark. -In order to be able to guarantee overflow detection on the sum with at least 100-billion values -and to be able to guarantee doing the divide with half up rounding at the end we only support +In order to be able to guarantee doing the divide with half up rounding at the end we only support average on input values with a precision of 23 or below. This is 38 - 10 for the sum guarantees and then 5 less to be able to shift the left-hand side of the divide enough to get a correct answer that can be rounded to the result that Spark would produce. diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 1aac3b00cc2..051c9a1906d 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -386,13 +386,13 @@ Accelerator supports are described below. S PS
UTC is only supported TZ for TIMESTAMP
S -PS
max DECIMAL precision of 18
+S S NS NS -PS
max child DECIMAL precision of 18;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
-PS
max child DECIMAL precision of 18;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
-PS
max child DECIMAL precision of 18;
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT
NS diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index 90d062fc816..d35035ae664 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -265,33 +265,59 @@ def get_params(init_list, marked_params=[]): ('b', decimal_gen_20_2), ('c', decimal_gen_20_2)] +# NOTE on older versions of Spark decimal 38 causes the CPU to crash +# instead of detect overflows, we have versions of this for both +# 36 and 38 so we can get some coverage for old versions and full +# coverage for newer versions +_grpkey_short_very_big_decimals = [ + ('a', RepeatSeqGen(short_gen, length=50)), + ('b', decimal_gen_36_5), + ('c', decimal_gen_36_5)] + +_grpkey_short_very_big_neg_scale_decimals = [ + ('a', RepeatSeqGen(short_gen, length=50)), + ('b', decimal_gen_36_neg5), + ('c', decimal_gen_36_neg5)] + +_grpkey_short_full_decimals = [ + ('a', RepeatSeqGen(short_gen, length=50)), + ('b', decimal_gen_38_0), + ('c', decimal_gen_38_0)] + +_grpkey_short_full_neg_scale_decimals = [ + ('a', RepeatSeqGen(short_gen, length=50)), + ('b', decimal_gen_38_neg10), + ('c', decimal_gen_38_neg10)] + + _init_list_no_nans_with_decimal = _init_list_no_nans + [ _grpkey_small_decimals] _init_list_no_nans_with_decimalbig = _init_list_no_nans + [ - _grpkey_small_decimals, _grpkey_short_mid_decimals, _grpkey_short_big_decimals] + _grpkey_small_decimals, _grpkey_short_mid_decimals, + _grpkey_short_big_decimals, _grpkey_short_very_big_decimals, + _grpkey_short_very_big_neg_scale_decimals] + +_init_list_full_decimal = [_grpkey_short_full_decimals, + _grpkey_short_full_neg_scale_decimals] -#TODO when we can support sum on larger types https://github.com/NVIDIA/spark-rapids/issues/3944 -# we should move to a larger type and use a smaller count so we can avoid the long CPU run -# then we should look at splitting the reduction up so we do half on the CPU and half on the GPU -# So we can test compatabiliy too (but without spending even longer computing) -def test_hash_reduction_decimal_overflow_sum(): +#Any smaller precision takes way too long to process on the CPU +@pytest.mark.parametrize('precision', [38, 37, 36, 35, 34, 33, 32, 31, 30], ids=idfn) +def test_hash_reduction_decimal_overflow_sum(precision): + constant = '9' * precision + count = pow(10, 38 - precision) assert_gpu_and_cpu_are_equal_collect( - # Spark adds +10 to precision so we need 10-billion entries before overflow is even possible. - # we use 10-billion and 2 to make sure we hit the overflow. 
- lambda spark: spark.range(0, 10000000002, 1, 48)\ - .selectExpr("CAST('9999999999' as Decimal(10, 0)) as a")\ - .selectExpr("SUM(a)"), - # set the batch size small because we can have limited GPU memory and the first select - # doubles the size of the batch - conf = {'spark.rapids.sql.batchSizeBytes': '64m'}) + lambda spark: spark.range(count)\ + .selectExpr("CAST('{}' as Decimal({}, 0)) as a".format(constant, precision))\ + .selectExpr("SUM(a)")) @pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn) def test_hash_grpby_sum_count_action(data_gen): assert_gpu_and_cpu_row_counts_equal( lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')) ) + @pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn) def test_hash_reduction_sum_count_action(data_gen): assert_gpu_and_cpu_row_counts_equal( @@ -307,19 +333,41 @@ def test_hash_reduction_sum_count_action(data_gen): def test_hash_grpby_sum(data_gen, conf): assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')), - conf=conf - ) + conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf)) +@pytest.mark.skipif(is_before_spark_311(), reason="SUM overflows for CPU were fixed in Spark 3.1.1") +@shuffle_test @approximate_float @ignore_order @incompat -@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', _init_list_full_decimal, ids=idfn) +@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn) +def test_hash_grpby_sum_full_decimal(data_gen, conf): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')), + conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf)) + +@approximate_float +@ignore_order +@incompat +@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens + [decimal_gen_20_2, decimal_gen_30_2, decimal_gen_36_5, decimal_gen_36_neg5], ids=idfn) @pytest.mark.parametrize('conf', get_params(_confs_with_nans, params_markers_for_confs_nans), ids=idfn) def test_hash_reduction_sum(data_gen, conf): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen, length=100).selectExpr("SUM(a)"), conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf)) +@pytest.mark.skipif(is_before_spark_311(), reason="SUM overflows for CPU were fixed in Spark 3.1.1") +@approximate_float +@ignore_order +@incompat +@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens + [decimal_gen_38_0, decimal_gen_38_10, decimal_gen_38_neg10], ids=idfn) +@pytest.mark.parametrize('conf', get_params(_confs_with_nans, params_markers_for_confs_nans), ids=idfn) +def test_hash_reduction_sum_full_decimal(data_gen, conf): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, data_gen, length=100).selectExpr("SUM(a)"), + conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf)) + @approximate_float @ignore_order @incompat diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py index 97c045ed5e5..0bf6f1f0bc5 100644 --- a/integration_tests/src/main/python/window_function_test.py +++ b/integration_tests/src/main/python/window_function_test.py @@ -164,7 +164,7 @@ def test_decimal128_count_window_no_part(data_gen): conf = allow_negative_scale_of_decimal_conf) @ignore_order -@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) 
+@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn) def test_decimal_sum_window(data_gen): assert_gpu_and_cpu_are_equal_sql( lambda spark: three_col_df(spark, byte_gen, LongRangeGen(), data_gen), @@ -177,7 +177,7 @@ def test_decimal_sum_window(data_gen): conf = allow_negative_scale_of_decimal_conf) @ignore_order -@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn) def test_decimal_sum_window_no_part(data_gen): assert_gpu_and_cpu_are_equal_sql( lambda spark: two_col_df(spark, LongRangeGen(), data_gen), @@ -191,7 +191,7 @@ def test_decimal_sum_window_no_part(data_gen): @ignore_order -@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn) def test_decimal_running_sum_window(data_gen): assert_gpu_and_cpu_are_equal_sql( lambda spark: three_col_df(spark, byte_gen, LongRangeGen(), data_gen), @@ -205,7 +205,7 @@ def test_decimal_running_sum_window(data_gen): {'spark.rapids.sql.batchSizeBytes': '100'})) @ignore_order -@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn) def test_decimal_running_sum_window_no_part(data_gen): assert_gpu_and_cpu_are_equal_sql( lambda spark: two_col_df(spark, LongRangeGen(), data_gen), diff --git a/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala b/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala index 4be0e643cf7..09bbf280896 100644 --- a/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala +++ b/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala @@ -16,13 +16,6 @@ package org.apache.spark.sql.rapids.aggregate -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.GpuSumBase -import org.apache.spark.sql.types.DataType - -case class GpuSum(child: Expression, - resultType: DataType, - failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled) extends GpuSumBase { - override val extraDecimalOverflowChecks: Boolean = false +object GpuSumDefaults { + val hasIsEmptyField: Boolean = false } \ No newline at end of file diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala index ae748013695..057f588eddf 100644 --- a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/aggregate/GpuSum.scala @@ -16,13 +16,6 @@ package org.apache.spark.sql.rapids.aggregate -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.GpuSumBase -import org.apache.spark.sql.types.DataType - -case class GpuSum(child: Expression, - resultType: DataType, - failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled) extends GpuSumBase { - override val extraDecimalOverflowChecks: Boolean = true -} \ No newline at end of file +object GpuSumDefaults { + val hasIsEmptyField: Boolean = true +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 
dec844f5f10..87289552899 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -1345,24 +1345,30 @@ object GpuCast extends Arm { } } + def fixDecimalBounds(input: ColumnView, + outOfBounds: ColumnView, + ansiMode: Boolean): ColumnVector = { + if (ansiMode) { + withResource(outOfBounds.any()) { isAny => + if (isAny.isValid && isAny.getBoolean) { + throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE) + } + } + input.copyToColumnVector() + } else { + withResource(Scalar.fromNull(input.getType)) { nullVal => + outOfBounds.ifElse(nullVal, input) + } + } + } + def checkNFixDecimalBounds( input: ColumnView, to: DecimalType, ansiMode: Boolean): ColumnVector = { assert(input.getType.isDecimalType) withResource(DecimalUtil.outOfBounds(input, to)) { outOfBounds => - if (ansiMode) { - withResource(outOfBounds.any()) { isAny => - if (isAny.isValid && isAny.getBoolean) { - throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE) - } - } - input.copyToColumnVector() - } else { - withResource(Scalar.fromNull(input.getType)) { nullVal => - outOfBounds.ifElse(nullVal, input) - } - } + fixDecimalBounds(input, outOfBounds, ansiMode) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 919ffc775e3..be36eba8751 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -56,7 +56,6 @@ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.hive.rapids.GpuHiveOverrides import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids._ -import org.apache.spark.sql.rapids.aggregate.GpuSum import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand import org.apache.spark.sql.rapids.execution._ import org.apache.spark.sql.rapids.execution.python._ @@ -2264,21 +2263,6 @@ object GpuOverrides extends Logging { override def tagAggForGpu(): Unit = { val inputDataType = a.child.dataType checkAndTagFloatAgg(inputDataType, conf, this) - - a.dataType match { - case _: DecimalType => - val unboundPrecision = a.child.dataType.asInstanceOf[DecimalType].precision + 10 - if (unboundPrecision > DType.DECIMAL128_MAX_PRECISION) { - if (conf.needDecimalGuarantees) { - willNotWorkOnGpu("overflow checking on sum would need " + - s"a precision of $unboundPrecision to properly detect overflows") - } else { - logWarning("Decimal overflow guarantees disabled for " + - s"sum(${a.child.dataType}) produces ${a.dataType}") - } - } - case _ => // NOOP - } } override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = @@ -3678,7 +3662,7 @@ object GpuOverrides extends Logging { exec[SampleExec]( "The backend for the sample operator", ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + - TypeSig.ARRAY + TypeSig.DECIMAL_64).nested(), TypeSig.all), + TypeSig.ARRAY + TypeSig.DECIMAL_128_FULL).nested(), TypeSig.all), (sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) ), ShimLoader.getSparkShims.aqeShuffleReaderExec, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala index aeda8d8b394..fa873dbaac8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala @@ -322,11 +322,13 @@ object GpuWindowExec extends Arm { // First pass replace any operations that should be totally replaced. val replacePass = expr.transformDown { case GpuWindowExpression( - GpuAggregateExpression(rep: GpuReplaceWindowFunction, _, _, _, _), spec) => + GpuAggregateExpression(rep: GpuReplaceWindowFunction, _, _, _, _), spec) + if rep.shouldReplaceWindow(spec) => // We don't actually care about the GpuAggregateExpression because it is ignored // by our GPU window operations anyways. rep.windowReplacement(spec) - case GpuWindowExpression(rep: GpuReplaceWindowFunction, spec) => + case GpuWindowExpression(rep: GpuReplaceWindowFunction, spec) + if rep.shouldReplaceWindow(spec)=> rep.windowReplacement(spec) } // Second pass looks for GpuWindowFunctions and GpuWindowSpecDefinitions to build up @@ -372,6 +374,8 @@ object GpuWindowExec extends Arm { def isRunningWindow(spec: GpuWindowSpecDefinition): Boolean = spec match { case GpuWindowSpecDefinition(_, _, GpuSpecifiedWindowFrame(RowFrame, GpuSpecialFrameBoundary(UnboundedPreceding), GpuSpecialFrameBoundary(CurrentRow))) => true + case GpuWindowSpecDefinition(_, _, GpuSpecifiedWindowFrame(RowFrame, + GpuSpecialFrameBoundary(UnboundedPreceding), GpuLiteral(value, _))) if value == 0 => true case _ => false } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index ada4000e9c4..547a6cc0d64 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.rapids.{GpuAggregateExpression, GpuCreateNamedStruct} +import org.apache.spark.sql.rapids.{GpuAdd, GpuAggregateExpression, GpuCreateNamedStruct} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -630,6 +630,12 @@ trait GpuReplaceWindowFunction extends GpuWindowFunction { * in the future. */ def windowReplacement(spec: GpuWindowSpecDefinition): Expression + + /** + * Return true if windowReplacement should be called to replace this GpuWindowFunction with + * something else. 
+ */ + def shouldReplaceWindow(spec: GpuWindowSpecDefinition): Boolean = true } /** @@ -881,14 +887,17 @@ class BatchedRunningWindowBinaryFixer(val binOp: BinaryOp, val name: String) class SumBinaryFixer(toType: DataType, isAnsi: Boolean) extends BatchedRunningWindowFixer with Arm with Logging { private val name = "sum" - private val binOp = BinaryOp.ADD private var previousResult: Option[Scalar] = None + private var previousOverflow: Option[Scalar] = None - def updateState(finalOutputColumn: cudf.ColumnVector): Unit = { + def updateState(finalOutputColumn: cudf.ColumnVector, + wasOverflow: Option[cudf.ColumnVector]): Unit = { + val lastIndex = finalOutputColumn.getRowCount.toInt - 1 logDebug(s"$name: updateState from $previousResult to...") previousResult.foreach(_.close) - previousResult = - Some(finalOutputColumn.getScalarElement(finalOutputColumn.getRowCount.toInt - 1)) + previousResult = Some(finalOutputColumn.getScalarElement(lastIndex)) + previousOverflow.foreach(_.close()) + previousOverflow = wasOverflow.map(_.getScalarElement(lastIndex)) logDebug(s"$name: ... $previousResult") } @@ -911,8 +920,7 @@ class SumBinaryFixer(toType: DataType, isAnsi: Boolean) throw new IllegalArgumentException(s"Making a zero scalar for $other is not supported") } - override def fixUp(samePartitionMask: Either[cudf.ColumnVector, Boolean], - sameOrderMask: Option[Either[cudf.ColumnVector, Boolean]], + private[this] def fixUpNonDecimal(samePartitionMask: Either[cudf.ColumnVector, Boolean], windowedColumnOutput: cudf.ColumnView): cudf.ColumnVector = { logDebug(s"$name: fix up $previousResult $samePartitionMask") val ret = (previousResult, samePartitionMask) match { @@ -928,7 +936,7 @@ class SumBinaryFixer(toType: DataType, isAnsi: Boolean) } } withResource(nullsReplaced) { nullsReplaced => - nullsReplaced.binaryOp(binOp, prev, prev.getType) + nullsReplaced.binaryOp(BinaryOp.ADD, prev, prev.getType) } } else { // prev is NULL but NULL + something == NULL which we don't want @@ -948,7 +956,7 @@ class SumBinaryFixer(toType: DataType, isAnsi: Boolean) } } withResource(nullsReplaced) { nullsReplaced => - withResource(nullsReplaced.binaryOp(binOp, prev, prev.getType)) { updated => + withResource(nullsReplaced.binaryOp(BinaryOp.ADD, prev, prev.getType)) { updated => mask.ifElse(updated, windowedColumnOutput) } } @@ -958,15 +966,126 @@ class SumBinaryFixer(toType: DataType, isAnsi: Boolean) } } closeOnExcept(ret) { ret => - updateState(ret) + updateState(ret, None) + ret + } + } + + private[this] def fixUpDecimal(samePartitionMask: Either[cudf.ColumnVector, Boolean], + windowedColumnOutput: cudf.ColumnView, + dt: DecimalType): cudf.ColumnVector = { + logDebug(s"$name: fix up $previousResult $samePartitionMask") + val (ret, decimalOverflowOnAdd) = (previousResult, previousOverflow, samePartitionMask) match { + case (None, None, _) => + // The mask is all false so do nothing + withResource(Scalar.fromBool(false)) { falseVal => + closeOnExcept(ColumnVector.fromScalar(falseVal, + windowedColumnOutput.getRowCount.toInt)) { over => + (incRef(windowedColumnOutput), over) + } + } + case (Some(prev), Some(previousOver), scala.util.Right(mask)) => + if (mask) { + if (!prev.isValid) { + // So in the window operation we can have a null if all of the input values before it + // were also null or if we overflowed the result and inserted in a null. 
+ // + // If we overflowed, then all of the output for this group should be null, but the + // overflow check code can handle inserting that, so just inc the ref count and return + // the overflow column. + // + // If we didn't overflow, and the input is null then + // prev is NULL but NULL + something == NULL which we don't want, so also + // just increment the reference count and go on. + closeOnExcept(ColumnVector.fromScalar(previousOver, + windowedColumnOutput.getRowCount.toInt)) { over => + (incRef(windowedColumnOutput), over) + } + } else { + // The previous didn't overflow, so now we need to do the add and check for overflow. + val nullsReplaced = withResource(windowedColumnOutput.isNull) { nulls => + withResource(makeZeroScalar(windowedColumnOutput.getType)) { zero => + nulls.ifElse(zero, windowedColumnOutput) + } + } + withResource(nullsReplaced) { nullsReplaced => + closeOnExcept(nullsReplaced.binaryOp(BinaryOp.ADD, prev, prev.getType)) { added => + (added, GpuAdd.didDecimalOverflow(nullsReplaced, prev, added)) + } + } + } + } else { + // The mask is all false so do nothing + withResource(Scalar.fromBool(false)) { falseVal => + closeOnExcept(ColumnVector.fromScalar(falseVal, + windowedColumnOutput.getRowCount.toInt)) { over => + (incRef(windowedColumnOutput), over) + } + } + } + case (Some(prev), Some(previousOver), scala.util.Left(mask)) => + if (prev.isValid) { + // The previous didn't overflow, so now we need to do the add and check for overflow. + val nullsReplaced = withResource(windowedColumnOutput.isNull) { nulls => + withResource(nulls.and(mask)) { shouldReplace => + withResource(makeZeroScalar(windowedColumnOutput.getType)) { zero => + shouldReplace.ifElse(zero, windowedColumnOutput) + } + } + } + withResource(nullsReplaced) { nullsReplaced => + withResource(nullsReplaced.binaryOp(BinaryOp.ADD, prev, prev.getType)) { added => + closeOnExcept(mask.ifElse(added, windowedColumnOutput)) { updated => + withResource(Scalar.fromBool(false)) { falseVal => + withResource(GpuAdd.didDecimalOverflow(nullsReplaced, prev, added)) { over => + (updated, mask.ifElse(over, falseVal)) + } + } + } + } + } + } else { + // So in the window operation we can have a null if all of the input values before it + // were also null or if we overflowed the result and inserted in a null. + // + // If we overflowed, then all of the output for this group should be null, but the + // overflow check code can handle inserting that, so just inc the ref count and return + // the overflow column. + // + // If we didn't overflow, and the input is null then + // prev is NULL but NULL + something == NULL which we don't want, so also + // just increment the reference count and go on. 
+ closeOnExcept(ColumnVector.fromScalar(previousOver, + windowedColumnOutput.getRowCount.toInt)) { over => + (incRef(windowedColumnOutput), over) + } + } + case _ => + throw new IllegalStateException("INTERNAL ERROR: Should never have a situation where " + + "prev and previousOver do not match.") } + withResource(ret) { ret => + withResource(decimalOverflowOnAdd) { decimalOverflowOnAdd => + withResource(DecimalUtil.outOfBounds(ret, dt)) { valOutOfBounds => + withResource(valOutOfBounds.or(decimalOverflowOnAdd)) { outOfBounds => + closeOnExcept(GpuCast.fixDecimalBounds(ret, outOfBounds, isAnsi)) { replaced => + updateState(replaced, Some(outOfBounds)) + replaced + } + } + } + } + } + } + + override def fixUp(samePartitionMask: Either[cudf.ColumnVector, Boolean], + sameOrderMask: Option[Either[cudf.ColumnVector, Boolean]], + windowedColumnOutput: cudf.ColumnView): cudf.ColumnVector = { toType match { case dt: DecimalType => - withResource(ret) { ret => - GpuCast.checkNFixDecimalBounds(ret, dt, isAnsi) - } + fixUpDecimal(samePartitionMask, windowedColumnOutput, dt) case _ => - ret + fixUpNonDecimal(samePartitionMask, windowedColumnOutput) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index 448ddfe6be6..7ccb95ee267 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -188,6 +188,7 @@ class GpuHashAggregateIterator( metrics: GpuHashAggregateMetrics, configuredTargetBatchSize: Long) extends Iterator[ColumnarBatch] with Arm with AutoCloseable with Logging { + // Partial mode: // 1. boundInputReferences: picks column from raw input // 2. boundFinalProjections: is a pass-through of the agg buffer @@ -205,7 +206,6 @@ class GpuHashAggregateIterator( // (GpuAverage => CudfSum/CudfCount) // 3. boundResultReferences: project the result expressions Spark expects in the output. private case class BoundExpressionsModeAggregates( - boundInputReferences: Seq[GpuExpression], boundFinalProjections: Option[Seq[GpuExpression]], boundResultReferences: Seq[Expression]) @@ -277,13 +277,10 @@ class GpuHashAggregateIterator( /** Aggregate all input batches and place the results in the aggregatedBatches queue. */ private def aggregateInputBatches(): Unit = { - val aggHelper = new AggHelper(merge = false) + val aggHelper = new AggHelper(forceMerge = false) while (cbIter.hasNext) { - val (childBatch, isLastInputBatch) = withResource(cbIter.next()) { inputBatch => - val isLast = GpuColumnVector.isTaggedAsFinalBatch(inputBatch) - (processIncomingBatch(inputBatch), isLast) - } - withResource(childBatch) { _ => + withResource(cbIter.next()) { childBatch => + val isLastInputBatch = GpuColumnVector.isTaggedAsFinalBatch(childBatch) withResource(computeAggregate(childBatch, aggHelper)) { aggBatch => val batch = LazySpillableColumnarBatch(aggBatch, metrics.spillCallback, "aggbatch") // Avoid making batch spillable for the common case of the last and only batch @@ -385,7 +382,7 @@ class GpuHashAggregateIterator( wasBatchMerged } - private lazy val concatAndMergeHelper = new AggHelper(merge = true) + private lazy val concatAndMergeHelper = new AggHelper(forceMerge = true) /** * Concatenate batches together and perform a merge aggregation on the result. 
The input batches @@ -544,27 +541,6 @@ class GpuHashAggregateIterator( } } - /** Perform the initial projection on the input batch and extract the result columns */ - private def processIncomingBatch(batch: ColumnarBatch): ColumnarBatch = { - val aggTime = metrics.computeAggTime - val opTime = metrics.opTime - withResource(new NvtxWithMetrics("prep agg batch", NvtxColor.CYAN, aggTime, - opTime)) { _ => - val cols = boundExpressions.boundInputReferences.safeMap { ref => - val childCv = GpuExpressionsUtils.columnarEvalToColumn(ref, batch) - if (DataType.equalsStructurally(childCv.dataType, ref.dataType, ignoreNullability = true)) { - childCv - } else { - withResource(childCv) { childCv => - val rapidsType = GpuColumnVector.getNonNestedRapidsType(ref.dataType) - GpuColumnVector.from(childCv.getBase.castTo(rapidsType), ref.dataType) - } - } - } - new ColumnarBatch(cols.toArray, batch.numRows()) - } - } - /** * Concatenates batches after extracting them from `LazySpillableColumnarBatch` * @note the input batches are not closed as part of this operation @@ -603,30 +579,6 @@ class GpuHashAggregateIterator( val aggBufferAttributes = groupingAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - // Adapted from `AggregationIterator.initializeAggregateFunctions` in Spark: - // - we use the "imperative aggregate" way as it used bound expressions due to - // lack of support of codegen (like our case) - // - for partial/complete: we bind to the inputProjection as specified by each - // `GpuAggregateFunction` to the `inputAttributes` (see how those are defined) - // - for partial merge/final: it is the pass through case, we are getting as input - // the "agg buffer", and we are using `inputAggBufferAttributes` to match the Spark - // function. We still bind to `inputAttributes`, as those would be setup for pass-through - // in the partial merge/final cases. 
- val aggBound = aggregateExpressions.flatMap { agg => - agg.mode match { - case Partial | Complete => - agg.aggregateFunction.inputProjection - case PartialMerge | Final => - agg.aggregateFunction.inputAggBufferAttributes - case mode => - throw new NotImplementedError(s"can't translate ${mode}") - } - } - - val boundInputReferences = GpuBindReferences.bindGpuReferences( - groupingExpressions ++ aggBound, - inputAttributes) - val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { val finalProjections = groupingExpressions ++ aggregateExpressions.map(_.aggregateFunction.evaluateExpression) @@ -659,7 +611,6 @@ class GpuHashAggregateIterator( groupingAttributes) } BoundExpressionsModeAggregates( - boundInputReferences, boundFinalProjections, boundResultReferences) } @@ -667,12 +618,12 @@ class GpuHashAggregateIterator( /** * Internal class used in `computeAggregates` for the pre, agg, and post steps * - * @param merge - if true, we are merging two pre-aggregated batches, so we should use + * @param forceMerge - if true, we are merging two pre-aggregated batches, so we should use * the merge steps for each aggregate function * @param isSorted - if the batch is sorted this is set to true and is passed to cuDF * as an optimization hint */ - class AggHelper(merge: Boolean, isSorted: Boolean = false) { + class AggHelper(forceMerge: Boolean, isSorted: Boolean = false) { // `CudfAggregate` instances to apply, either update or merge aggregates private val cudfAggregates = new mutable.ArrayBuffer[CudfAggregate]() @@ -694,10 +645,8 @@ class GpuHashAggregateIterator( private val postStepAttr = new mutable.ArrayBuffer[Attribute]() // we add the grouping expression first, which bind as pass-through - preStep ++= GpuBindReferences.bindGpuReferences( - groupingAttributes, groupingAttributes) - postStep ++= GpuBindReferences.bindGpuReferences( - groupingAttributes, groupingAttributes) + preStep ++= groupingExpressions + postStep ++= groupingAttributes postStepAttr ++= groupingAttributes postStepDataTypes ++= groupingExpressions.map(_.dataType) @@ -705,14 +654,14 @@ class GpuHashAggregateIterator( private var ix = groupingAttributes.length for (aggExp <- aggregateExpressions) { val aggFn = aggExp.aggregateFunction - if ((aggExp.mode == Partial || aggExp.mode == Complete) && !merge) { + if ((aggExp.mode == Partial || aggExp.mode == Complete) && !forceMerge) { val ordinals = (ix until ix + aggFn.updateAggregates.length) aggOrdinals ++= ordinals ix += ordinals.length val updateAggs = aggFn.updateAggregates postStepDataTypes ++= updateAggs.map(_.dataType) cudfAggregates ++= updateAggs - preStep ++= aggFn.aggBufferAttributes + preStep ++= aggFn.inputProjection postStep ++= aggFn.postUpdate postStepAttr ++= aggFn.postUpdateAttr } else { @@ -729,8 +678,11 @@ class GpuHashAggregateIterator( } // a bound expression that is applied before the cuDF aggregate - private val preStepBound = + private val preStepBound = if (forceMerge) { GpuBindReferences.bindGpuReferences(preStep, aggBufferAttributes) + } else { + GpuBindReferences.bindGpuReferences(preStep, inputAttributes) + } // a bound expression that is applied after the cuDF aggregate private val postStepBound = @@ -739,7 +691,8 @@ class GpuHashAggregateIterator( /** * Apply the "pre" step: preMerge for merge, or pass-through in the update case * @param toAggregateBatch - input (to the agg) batch from the child directly in the - * merge case, or from the `inputProjection` in the update case. 
+ * merge + * case, or from the `inputProjection` in the update case. * @return a pre-processed batch that can be later cuDF aggregated */ def preProcess(toAggregateBatch: ColumnarBatch): ColumnarBatch = { @@ -1388,11 +1341,11 @@ case class GpuHashAggregateExec( private val inputAggBufferAttributes: Seq[Attribute] = { aggregateExpressions - // there're exactly four cases needs `inputAggBufferAttributes` from child according to the - // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final, - // Partial -> PartialMerge, PartialMerge -> PartialMerge. - .filter(a => a.mode == Final || a.mode == PartialMerge) - .flatMap(_.aggregateFunction.inputAggBufferAttributes) + // there're exactly four cases needs `inputAggBufferAttributes` from child according to the + // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final, + // Partial -> PartialMerge, PartialMerge -> PartialMerge. + .filter(a => a.mode == Final || a.mode == PartialMerge) + .flatMap(_.aggregateFunction.aggBufferAttributes) } private lazy val uniqueModes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct @@ -1519,7 +1472,7 @@ case class GpuHashAggregateExec( */ override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index a882a19dde1..412e0293bac 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -381,6 +381,7 @@ case class GpuSampleExec(lowerBound: Double, upperBound: Double, withReplacement override lazy val additionalMetrics: Map[String, GpuMetric] = Map( OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_OP_TIME)) + // TODO CPU vs GPU OP TIME in Debug mode??? 
override def output: Seq[Attribute] = { child.output @@ -419,32 +420,31 @@ case class GpuSampleExec(lowerBound: Double, upperBound: Double, withReplacement val sampler = new BernoulliCellSampler(lowerBound, upperBound) sampler.setSeed(seed + index) iterator.map[ColumnarBatch] { batch => - numOutputBatches += 1 withResource(batch) { b => // will generate new columnar column, close this - val numRows = b.numRows() - val filter = withResource(HostColumnVector.builder(DType.BOOL8, numRows)) { - builder => - (0 until numRows).foreach { _ => - val n = sampler.sample() - if (n > 0) { - builder.append(1.toByte) - numOutputRows += 1 - } else { - builder.append(0.toByte) + withResource(new NvtxWithMetrics("sample", NvtxColor.DARK_GREEN, opTime)) { _ => + val numRows = b.numRows() + val filter = withResource(HostColumnVector.builder(DType.BOOL8, numRows)) { + builder => + var i = 0 + while (i < numRows) { + i = i + 1 + builder.append(if (sampler.sample() > 0) 1.toByte else 0.toByte) } - } - builder.buildAndPutOnDevice() - } + builder.buildAndPutOnDevice() + } - // use GPU filer rows - val colTypes = GpuColumnVector.extractTypes(b) - withResource(filter) { filter => - withResource(GpuColumnVector.from(b)) { tbl => - withResource(tbl.filter(filter)) { filteredData => - if (filteredData.getRowCount == 0) { - GpuColumnVector.emptyBatchFromTypes(colTypes) - } else { - GpuColumnVector.from(filteredData, colTypes) + // use GPU filer rows + val colTypes = GpuColumnVector.extractTypes(b) + withResource(filter) { filter => + withResource(GpuColumnVector.from(b)) { tbl => + withResource(tbl.filter(filter)) { filteredData => + numOutputBatches += 1 + numOutputRows += filteredData.getRowCount + if (filteredData.getRowCount == 0) { + GpuColumnVector.emptyBatchFromTypes(colTypes) + } else { + GpuColumnVector.from(filteredData, colTypes) + } } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index c9fc473e221..4e83885cc06 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeS import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils} -import org.apache.spark.sql.rapids.aggregate.GpuSum +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.aggregate.GpuSumDefaults import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -176,14 +177,12 @@ trait GpuAggregateFunction extends GpuExpression */ val evaluateExpression: Expression - /** Attributes of fields in aggBufferSchema. */ - def aggBufferAttributes: Seq[AttributeReference] - /** - * Result of the aggregate function when the input is empty. This is currently only used for the - * proper rewriting of distinct aggregate functions. + * This is the contract with the outside world. It describes what the output of postUpdate should + * look like, and what the input to preMerge looks like. It also describes what the output of + * postMerge must look like. 
   */
-  def defaultResult: Option[GpuLiteral] = None
+  def aggBufferAttributes: Seq[AttributeReference]
 
   def sql(isDistinct: Boolean): String = {
     val distinct = if (isDistinct) "DISTINCT " else ""
@@ -196,19 +195,8 @@ trait GpuAggregateFunction extends GpuExpression
     prettyName + flatArguments.mkString(start, ", ", ")")
   }
 
-  /**
-   * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are
-   * merged with mutable aggregation buffers in the merge() function or merge expressions).
-   * These attributes are created automatically by cloning the [[aggBufferAttributes]].
-   */
-  final lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
-
   /** An aggregate function is not foldable. */
   final override def foldable: Boolean = false
-
-  /** The schema of the aggregation buffer. */
-  def aggBufferSchema: StructType = null //not used in GPU version
 }
 
 case class WrappedAggFunction(aggregateFunction: GpuAggregateFunction, filter: Expression)
@@ -575,6 +563,95 @@ case class GpuMax(child: Expression) extends GpuAggregateFunction
   }
 }
 
+/**
+ * All decimal processing in Spark has overflow detection as a part of it. Either it replaces
+ * the value with a null in non-ANSI mode, or it throws an exception in ANSI mode. Spark will also
+ * do the processing for larger values as `Decimal` values, which are based on `BigDecimal` and
+ * have unbounded precision. So in most cases it is impossible to overflow/underflow so much that
+ * an incorrect value is returned. Spark will just use more and more memory to hold the value and
+ * then check for overflow at some point when the result needs to be turned back into a 128-bit
+ * value.
+ *
+ * We cannot do the same thing. Instead we take three strategies to detect overflow.
+ *
+ * 1. For decimal values with a precision of 8 or under we follow Spark and do the SUM
+ *    on the unscaled value as a long, and then bit-cast the result back to a Decimal value.
+ *    This means that we can SUM `174,467,442,481` maximum or minimum decimal values with a
+ *    precision of 8 before overflow can no longer be detected. It is much higher for decimal
+ *    values with a smaller precision.
+ * 2. For decimal values with a precision from 9 to 28 inclusive we sum them as 128-bit values.
+ *    This is very similar to what we do in the first strategy. The main differences are that we
+ *    use a 128-bit value when doing the sum, and we check for overflow after processing a batch.
+ *    In the case of group-by and reduction that happens after the update stage and also after
+ *    each merge stage. In the worst case this means that we can SUM `24,028,236,692` maximum or
+ *    minimum decimal values with a precision of 28 before overflow can no longer be detected.
+ * 3. For anything larger than precision 28 we do the same things we do for strategy 2, but we
+ *    also take the digits above 28 and sum them separately. We then check to see if they would
+ *    have overflowed the original limits. This lets us detect overflow in cases where the
+ *    original value would have wrapped around. The reason this works is because we have a hard
+ *    limit on the maximum number of values in a single batch being processed: `Int.MaxValue`, or
+ *    about 2.2 billion values. So we use a precision on the higher values that is large enough
+ *    to handle 2.2 billion values and still detect overflow. This equates to a precision of
+ *    about 10 more than is needed to hold the higher digits. This effectively gives us unlimited
+ *    overflow detection.
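+ *
+ *    As a worked example of strategy 3 (the numbers follow from the constants in
+ *    `GpuDecimalSumOverflow` below): for a `Decimal(38, 0)` input, `GpuDecimalSumHighDigits`
+ *    divides each unscaled value by `10^28` and the quotients are summed as a `Decimal(21, 0)`
+ *    (38 + 10 + 1 - 28). If that side sum no longer fits in a `Decimal(10, 0)` (38 - 28), then
+ *    the main 128-bit sum must have overflowed, even if it wrapped back into the valid range.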
+ *
+ * We do want to have some kind of guarantees that are large
+ * enough that users feel comfortable using our framework for doing the processing. In Spark
+ * there are a few optimizations around SUM where, if the intermediate result fits in 64-bits,
+ * Spark will use a long to do the SUM instead of processing the values using `Decimal`. This
+ * speeds the processing up a lot, but it also means that Spark can only detect the overflow if
+ * the result does not wrap around and go back to being a valid number. The formula for the
+ * number of worst-case values this takes is:
+ *
+ * `(Long.MaxValue * 2 − DEC_18_MAX) ÷ DEC_8_MAX ≈ 174,467,442,481`
+ *
+ * `DEC_18_MAX` is the maximum value that a `Decimal(18, 0)` can hold. This is because Spark will
+ * do the SUM as this type, so that it gives us a lot of values before overflow is even possible
+ * (10-billion), because it adds 10 to the input precision. `DEC_8_MAX` is the maximum value that
+ * a `Decimal(8, 0)` can hold. This means we can have at least 174 billion values before Spark
+ * could possibly return a bogus value, or over 17 times the number of values it takes to hit the
+ * overflow case to begin with. In practice it will likely be a lot more than that, because not
+ * all values will be the maximum (or minimum) values allowed.
+ *
+ * We don't necessarily need to match Spark exactly, but we can come close because each batch can
+ * have at most `Int.MaxValue` rows in it, so we only have to guarantee that we can detect
+ * overflow across about 2.2 billion values within a single batch. This means that we want to try
+ * to match Spark, and ideally have something around 100-billion+ values before overflow is no
+ * longer detectable in the worst case.
+ *
+ * TODO look at merge vs update for this. For update we just need to support `Int.MaxValue`
+ * values before we can no longer detect overflow. For merge we are combining multiple of these
+ * already-checked results together, so if we can support `Int.MaxValue` of those merged values,
+ * then we know we will never be bitten by overflow.
+ */
+object GpuDecimalSumOverflow {
+  /**
+   * The increase in precision for the output of a SUM from the input. This is hard coded by
+   * Spark so we just have it here. This means that for most types (those not already capped at
+   * a precision of 38) you get 10-billion+ values before an overflow would even be possible.
+   */
+  val sumPrecisionIncrease: Int = 10
+
+  /**
+   * Generally we want a guarantee that is at least 10x larger than the original overflow.
+   */
+  val extraGuaranteePrecision: Int = 1
+
+  /**
+   * The precision above which we need extra overflow checks while doing an update. This is
+   * because anything above this precision could in theory overflow beyond detection within a
+   * single input batch.
+   */
+  val updateCutoffPrecision: Int = 28
+
+  /**
+   * This is the precision above which we need to do extra checks for overflow when merging
+   * results. This is because anything above this precision could in theory overflow a decimal128
+   * value beyond detection in a batch of already updated and checked values.
+   */
+  val mergeCutoffPrecision: Int = 20
+}
+
 /**
  * This is equivalent to what Spark does after a sum to check for overflow
  * `
@@ -627,15 +704,122 @@ case class GpuCheckOverflowAfterSum(
   override def children: Seq[Expression] = Seq(data, isEmpty)
 }
 
-trait GpuSumBase extends GpuAggregateFunction with ImplicitCastInputTypes
-    with GpuBatchedRunningWindowWithFixer
-    with GpuAggregateWindowFunction
-    with GpuRunningWindowFunction {
+/**
+ * This extracts the highest digits from a Decimal value as a part of doing a SUM.
+ */ +case class GpuDecimalSumHighDigits( + input: Expression, + originalInputType: DecimalType) extends GpuExpression with ShimExpression { + + override def nullable: Boolean = input.nullable + + override def toString: String = s"GpuDecimalSumHighDigits($input)" + + override def sql: String = input.sql + + override val dataType: DecimalType = DecimalType(originalInputType.precision + + GpuDecimalSumOverflow.sumPrecisionIncrease + GpuDecimalSumOverflow.extraGuaranteePrecision - + GpuDecimalSumOverflow.updateCutoffPrecision, 0) + // Marking these as lazy because they are not serializable + private lazy val outputDType = GpuColumnVector.getNonNestedRapidsType(dataType) + private lazy val intermediateDType = + DType.create(DType.DTypeEnum.DECIMAL128, outputDType.getScale) + + private lazy val divisionFactor: Decimal = + Decimal(math.pow(10, GpuDecimalSumOverflow.updateCutoffPrecision)) + private val divisionType = DecimalType(38, 0) + + override def columnarEval(batch: ColumnarBatch): Any = { + withResource(GpuProjectExec.projectSingle(batch, input)) { inputCol => + val inputBase = inputCol.getBase + // We don't have direct access to 128 bit ints so we use a decimal with a scale of 0 + // as a stand in. + val bitCastInputType = DType.create(DType.DTypeEnum.DECIMAL128, 0) + val divided = withResource(inputBase.bitCastTo(bitCastInputType)) { bitCastInput => + withResource(GpuScalar.from(divisionFactor, divisionType)) { divisor => + bitCastInput.div(divisor, intermediateDType) + } + } + val ret = withResource(divided) { divided => + if (divided.getType.equals(outputDType)) { + divided.incRefCount() + } else { + divided.castTo(outputDType) + } + } + GpuColumnVector.from(ret, dataType) + } + } + + override def children: Seq[Expression] = Seq(input) +} + +/** + * Return a boolean if this decimal overflowed or not + */ +case class GpuDecimalDidOverflow( + data: Expression, + rangeType: DecimalType, + nullOnOverflow: Boolean) extends GpuExpression with ShimExpression { - val child: Expression - val resultType: DataType - val failOnErrorOverride: Boolean - val extraDecimalOverflowChecks: Boolean + override def nullable: Boolean = true + + override def toString: String = + s"GpuDecimalDidOverflow($data, $rangeType, $nullOnOverflow)" + + override def sql: String = data.sql + + override def dataType: DataType = BooleanType + + override def columnarEval(batch: ColumnarBatch): Any = { + withResource(GpuProjectExec.projectSingle(batch, data)) { dataCol => + val dataBase = dataCol.getBase + withResource(DecimalUtil.outOfBounds(dataBase, rangeType)) { outOfBounds => + if (!nullOnOverflow) { + withResource(outOfBounds.any()) { isAny => + if (isAny.isValid && isAny.getBoolean) { + throw new ArithmeticException("Overflow as a part of SUM") + } + } + } else { + GpuColumnVector.from(outOfBounds.incRefCount(), dataType) + } + } + } + } + + override def children: Seq[Expression] = Seq(data) +} + +case class GpuSum(child: Expression, + resultType: DataType, + failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled, + forceWindowSumToNotBeReplaced: Boolean = false) + extends GpuAggregateFunction with ImplicitCastInputTypes + with GpuReplaceWindowFunction + with GpuBatchedRunningWindowWithFixer + with GpuAggregateWindowFunction + with GpuRunningWindowFunction { + + private lazy val childIsDecimal: Boolean = + child.dataType.isInstanceOf[DecimalType] + + private lazy val childDecimalType: DecimalType = + child.dataType.asInstanceOf[DecimalType] + + private lazy val needsDec128MergeOverflowChecks: Boolean = + 
childIsDecimal && childDecimalType.precision > GpuDecimalSumOverflow.mergeCutoffPrecision
+
+  private lazy val needsDec128UpdateOverflowChecks: Boolean =
+    childIsDecimal &&
+        childDecimalType.precision > GpuDecimalSumOverflow.updateCutoffPrecision
+
+  // For some operations we need to SUM the higher digits in addition to the regular value so
+  // we can detect overflow. This is the type of the higher digits SUM value.
+  private lazy val higherDigitsCheckType: DecimalType = {
+    val t = resultType.asInstanceOf[DecimalType]
+    DecimalType(t.precision - GpuDecimalSumOverflow.updateCutoffPrecision, 0)
+  }
 
   private lazy val zeroDec = {
     val dt = resultType.asInstanceOf[DecimalType]
@@ -643,74 +827,198 @@ trait GpuSumBase extends GpuAggregateFunction with ImplicitCastInputTypes
   }
 
   override lazy val initialValues: Seq[GpuLiteral] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
+    case _: DecimalType if GpuSumDefaults.hasIsEmptyField =>
       Seq(zeroDec, GpuLiteral(true, BooleanType))
     case _ => Seq(GpuLiteral(null, resultType))
   }
 
+  private lazy val updateHigherOrderBits = {
+    val input = if (child.dataType != resultType) {
+      GpuCast(child, resultType)
+    } else {
+      child
+    }
+    GpuDecimalSumHighDigits(input, childDecimalType)
+  }
+
   // we need to cast to `resultType` here, since Spark is not widening types
   // as done before Spark 3.2.0. See CudfSum for more info.
   override lazy val inputProjection: Seq[Expression] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
-      // Spark tracks null columns through a second column isEmpty for decimal.
-      Seq(GpuIf(GpuIsNull(child), zeroDec, GpuCast(child, resultType)), GpuIsNull(child))
+    case _: DecimalType =>
+      // Decimal is complicated...
+      if (GpuSumDefaults.hasIsEmptyField) {
+        // Spark tracks null columns through a second column isEmpty for decimal.
   // we need to cast to `resultType` here, since Spark is not widening types
   // as done before Spark 3.2.0. See CudfSum for more info.
   override lazy val inputProjection: Seq[Expression] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
-      // Spark tracks null columns through a second column isEmpty for decimal.
-      Seq(GpuIf(GpuIsNull(child), zeroDec, GpuCast(child, resultType)), GpuIsNull(child))
+    case _: DecimalType =>
+      // Decimal is complicated...
+      if (GpuSumDefaults.hasIsEmptyField) {
+        // Spark tracks null columns through a second column isEmpty for decimal. So null values
+        // are replaced with 0, and a separate boolean column tracking isNull is added
+        if (needsDec128UpdateOverflowChecks) {
+          // If we want extra checks for overflow, then we also want to include it here
+          Seq(GpuIf(GpuIsNull(child), zeroDec, GpuCast(child, resultType)),
+            GpuIsNull(child),
+            updateHigherOrderBits)
+        } else {
+          Seq(GpuIf(GpuIsNull(child), zeroDec, GpuCast(child, resultType)), GpuIsNull(child))
+        }
+      } else {
+        if (needsDec128UpdateOverflowChecks) {
+          // If we want extra checks for overflow, then we also want to include it here
+          Seq(GpuCast(child, resultType), updateHigherOrderBits)
+        } else {
+          Seq(GpuCast(child, resultType))
+        }
+      }
     case _ => Seq(GpuCast(child, resultType))
   }
 
   private lazy val updateSum = new CudfSum(resultType)
   private lazy val updateIsEmpty = new CudfMin(BooleanType)
+  private lazy val updateOverflow = new CudfSum(updateHigherOrderBits.dataType)
 
   override lazy val updateAggregates: Seq[CudfAggregate] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
-      Seq(updateSum, updateIsEmpty)
+    case _: DecimalType =>
+      if (GpuSumDefaults.hasIsEmptyField) {
+        if (needsDec128UpdateOverflowChecks) {
+          Seq(updateSum, updateIsEmpty, updateOverflow)
+        } else {
+          Seq(updateSum, updateIsEmpty)
+        }
+      } else {
+        if (needsDec128UpdateOverflowChecks) {
+          Seq(updateSum, updateOverflow)
+        } else {
+          Seq(updateSum)
+        }
+      }
     case _ => Seq(updateSum)
   }
 
+  private[this] def extendedPostUpdateDecOverflowCheck(dt: DecimalType) =
+    GpuCheckOverflow(
+      GpuIf(
+        GpuDecimalDidOverflow(updateOverflow.attr,
+          higherDigitsCheckType,
+          !failOnErrorOverride),
+        GpuLiteral(null, dt),
+        updateSum.attr),
+      dt, !failOnErrorOverride)
+
   override lazy val postUpdate: Seq[Expression] = resultType match {
-    case dt: DecimalType if extraDecimalOverflowChecks =>
-      Seq(GpuCheckOverflow(updateSum.attr, dt, !failOnErrorOverride), updateIsEmpty.attr)
+    case dt: DecimalType =>
+      if (GpuSumDefaults.hasIsEmptyField) {
+        if (needsDec128UpdateOverflowChecks) {
+          Seq(extendedPostUpdateDecOverflowCheck(dt), updateIsEmpty.attr)
+        } else {
+          Seq(GpuCheckOverflow(updateSum.attr, dt, !failOnErrorOverride), updateIsEmpty.attr)
+        }
+      } else {
+        if (needsDec128UpdateOverflowChecks) {
+          Seq(extendedPostUpdateDecOverflowCheck(dt))
+        } else {
+          postUpdateAttr
+        }
+      }
     case _ => postUpdateAttr
   }
 
   // output of GpuSum
   private lazy val sum = AttributeReference("sum", resultType)()
 
+  // Used for Decimal overflow detection
   private lazy val isEmpty = AttributeReference("isEmpty", BooleanType)()
   override lazy val aggBufferAttributes: Seq[AttributeReference] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
+    case _: DecimalType if GpuSumDefaults.hasIsEmptyField =>
       sum :: isEmpty :: Nil
     case _ => sum :: Nil
   }
 
+  private lazy val mergeHigherOrderBits = GpuDecimalSumHighDigits(sum, childDecimalType)
+
   override lazy val preMerge: Seq[Expression] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
-      Seq(sum, isEmpty, GpuIsNull(sum))
+    case _: DecimalType =>
+      if (GpuSumDefaults.hasIsEmptyField) {
+        if (needsDec128MergeOverflowChecks) {
+          Seq(sum, isEmpty, GpuIsNull(sum), mergeHigherOrderBits)
+        } else {
+          Seq(sum, isEmpty, GpuIsNull(sum))
+        }
+      } else {
+        if (needsDec128MergeOverflowChecks) {
+          Seq(sum, mergeHigherOrderBits)
+        } else {
+          aggBufferAttributes
+        }
+      }
     case _ => aggBufferAttributes
   }
 
   private lazy val mergeSum = new CudfSum(resultType)
   private lazy val mergeIsEmpty = new CudfMin(BooleanType)
   private lazy val mergeIsOverflow = new CudfMax(BooleanType)
+  private lazy val mergeOverflow = new CudfSum(mergeHigherOrderBits.dataType)
 
   // To be able to do decimal overflow detection, we need a CudfSum that does **not** ignore nulls.
   // Cudf does not have such an aggregation, so for merge we have to work around that similar to
   // what happens with isEmpty
   override lazy val mergeAggregates: Seq[CudfAggregate] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
-      Seq(mergeSum, mergeIsEmpty, mergeIsOverflow)
+    case _: DecimalType =>
+      if (GpuSumDefaults.hasIsEmptyField) {
+        if (needsDec128MergeOverflowChecks) {
+          Seq(mergeSum, mergeIsEmpty, mergeIsOverflow, mergeOverflow)
+        } else {
+          Seq(mergeSum, mergeIsEmpty, mergeIsOverflow)
+        }
+      } else {
+        if (needsDec128MergeOverflowChecks) {
+          Seq(mergeSum, mergeOverflow)
+        } else {
+          Seq(mergeSum)
+        }
+      }
     case _ => Seq(mergeSum)
   }
 
   override lazy val postMerge: Seq[Expression] = resultType match {
-    case _: DecimalType if extraDecimalOverflowChecks =>
-      Seq(GpuIf(mergeIsOverflow.attr, GpuLiteral.create(null, resultType), mergeSum.attr),
-        mergeIsEmpty.attr)
+    case dt: DecimalType =>
+      if (GpuSumDefaults.hasIsEmptyField) {
+        if (needsDec128MergeOverflowChecks) {
+          Seq(
+            GpuCheckOverflow(
+              GpuIf(
+                GpuOr(
+                  GpuDecimalDidOverflow(mergeOverflow.attr, higherDigitsCheckType,
+                    !failOnErrorOverride),
+                  mergeIsOverflow.attr),
+                GpuLiteral.create(null, resultType),
+                mergeSum.attr),
+              dt, !failOnErrorOverride),
+            mergeIsEmpty.attr)
+        } else {
+          Seq(
+            GpuCheckOverflow(GpuIf(mergeIsOverflow.attr,
+              GpuLiteral.create(null, resultType),
+              mergeSum.attr),
+              dt, !failOnErrorOverride),
+            mergeIsEmpty.attr)
+        }
+      } else {
+        if (needsDec128MergeOverflowChecks) {
+          Seq(
+            GpuCheckOverflow(
+              GpuIf(
+                GpuDecimalDidOverflow(mergeOverflow.attr, higherDigitsCheckType,
+                  !failOnErrorOverride),
+                GpuLiteral.create(null, resultType),
+                mergeSum.attr),
+              dt, !failOnErrorOverride))
+        } else {
+          postMergeAttr
+        }
+      }
+    case _ => postMergeAttr
   }
 
   override lazy val evaluateExpression: Expression = resultType match {
     case dt: DecimalType =>
-      if (extraDecimalOverflowChecks) {
+      if (GpuSumDefaults.hasIsEmptyField) {
         GpuCheckOverflowAfterSum(sum, isEmpty, dt, !failOnErrorOverride)
       } else {
         GpuCheckOverflow(sum, dt, !failOnErrorOverride)
@@ -726,6 +1034,32 @@ trait GpuSumBase extends GpuAggregateFunction with ImplicitCastInputTypes
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function gpu sum")
 
+  // Replacement Window Function
+  override def shouldReplaceWindow(spec: GpuWindowSpecDefinition): Boolean = {
+    // We will only replace this if we think an update will fail. In the cases where we can
+    // handle a window function larger than a single batch, we already have merge overflow
+    // detection enabled.
+    !forceWindowSumToNotBeReplaced && needsDec128UpdateOverflowChecks
+  }
+
+  override def windowReplacement(spec: GpuWindowSpecDefinition): Expression = {
+    // We need extra overflow checks for some larger decimal types. To do these checks we
+    // extract the higher digits and SUM them separately to see if they would overflow.
+    // If they do, we know that the regular SUM also overflowed. If not, we know that we can
+    // rely on the existing overflow code to detect it.
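    // Editor's illustration (not part of the patch): for a DECIMAL(32, 2) input,
    // the rewrite built below conceptually turns `SUM(x) OVER w` into
    //   IF(ISNULL(SUM(highDigits(CAST(x AS DECIMAL(38, 2)))) OVER w),
    //      NULL,
    //      SUM(x) OVER w)
    // The high-digits SUM goes null (or throws under ANSI mode) when it falls
    // outside its own check range, which is exactly when the full-precision SUM
    // could have wrapped.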
+ val regularSum = GpuWindowExpression( + GpuSum(child, resultType, failOnErrorOverride = failOnErrorOverride, + forceWindowSumToNotBeReplaced = true), + spec) + val highOrderDigitsSum = GpuWindowExpression( + GpuSum( + GpuDecimalSumHighDigits(GpuCast(child, resultType), childDecimalType), + higherDigitsCheckType, + failOnErrorOverride = failOnErrorOverride), + spec) + GpuIf(GpuIsNull(highOrderDigitsSum), GpuLiteral(null, resultType), regularSum) + } + // GENERAL WINDOW FUNCTION // Spark 3.2.0+ stopped casting the input data to the output type before the sum operation // This fixes that. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala index a349c875965..86de80c5703 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -144,17 +144,8 @@ abstract class CudfBinaryArithmetic extends CudfBinaryOperator with NullIntolera override def dataType: DataType = left.dataType } -case class GpuAdd( - left: Expression, - right: Expression, - failOnError: Boolean) extends CudfBinaryArithmetic { - override def inputType: AbstractDataType = TypeCollection.NumericAndInterval - - override def symbol: String = "+" - - override def binaryOp: BinaryOp = BinaryOp.ADD - - private[this] def basicOpOverflowCheck( +object GpuAdd extends Arm { + def basicOpOverflowCheck( lhs: BinaryOperable, rhs: BinaryOperable, ret: ColumnVector): Unit = { @@ -179,7 +170,7 @@ case class GpuAdd( } } - private[this] def decimalOpOverflowCheck( + def didDecimalOverflow( lhs: BinaryOperable, rhs: BinaryOperable, ret: ColumnVector): ColumnVector = { @@ -189,7 +180,7 @@ case class GpuAdd( // the result val numRows = ret.getRowCount.toInt val zero = BigDecimal(0) - val overflow = withResource(DecimalUtil.lessThan(rhs, zero, numRows)) { rhsLz => + withResource(DecimalUtil.lessThan(rhs, zero, numRows)) { rhsLz => val argsSignSame = withResource(DecimalUtil.lessThan(lhs, zero, numRows)) { lhsLz => lhsLz.equalTo(rhsLz) } @@ -203,7 +194,14 @@ case class GpuAdd( } } } - withResource(overflow) { overflow => + } + + def decimalOpOverflowCheck( + lhs: BinaryOperable, + rhs: BinaryOperable, + ret: ColumnVector, + failOnError: Boolean): ColumnVector = { + withResource(didDecimalOverflow(lhs, rhs, ret)) { overflow => if (failOnError) { withResource(overflow.any()) { any => if (any.isValid && any.getBoolean) { @@ -212,23 +210,34 @@ case class GpuAdd( } ret.incRefCount() } else { - withResource(GpuScalar.from(null, dataType)) { nullVal => + withResource(Scalar.fromNull(ret.getType)) { nullVal => overflow.ifElse(nullVal, ret) } } } } +} + +case class GpuAdd( + left: Expression, + right: Expression, + failOnError: Boolean) extends CudfBinaryArithmetic { + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval + + override def symbol: String = "+" + + override def binaryOp: BinaryOp = BinaryOp.ADD override def doColumnar(lhs: BinaryOperable, rhs: BinaryOperable): ColumnVector = { val ret = super.doColumnar(lhs, rhs) withResource(ret) { ret => // No shims are needed, because it actually supports ANSI mode from Spark v3.0.1. 
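      // Editor's note (illustrative, not part of the patch): didDecimalOverflow
      // above encodes the classic sign rule for wrapped addition: a + b can only
      // overflow when a and b share a sign, and it did overflow exactly when the
      // wrapped result's sign differs from theirs. For example, in 8-bit two's
      // complement arithmetic 100 + 100 wraps to -56: same operand signs, flipped
      // result sign, hence overflow. The same test applies to the 128-bit integer
      // representation backing DECIMAL128.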
if (failOnError && GpuAnsi.needBasicOpOverflowCheck(dataType)) { - basicOpOverflowCheck(lhs, rhs, ret) + GpuAdd.basicOpOverflowCheck(lhs, rhs, ret) } if (dataType.isInstanceOf[DecimalType]) { - decimalOpOverflowCheck(lhs, rhs, ret) + GpuAdd.decimalOpOverflowCheck(lhs, rhs, ret, failOnError) } else { ret.incRefCount() } From 7bcb853fd674c534b4eadc5da070defcca677352 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 2 Dec 2021 14:32:50 -0600 Subject: [PATCH 2/5] Addressed review comments --- docs/compatibility.md | 2 +- .../nvidia/spark/rapids/GpuWindowExec.scala | 2 +- .../spark/rapids/basicPhysicalOperators.scala | 1 - .../spark/sql/rapids/AggregateFunctions.scala | 48 ++++++------------- 4 files changed, 16 insertions(+), 37 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index a3e9862588e..f822e0aa97d 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -120,7 +120,7 @@ GPU. If this concerns you, you should upgrade to Spark 3.1.0 or above. When Apache Spark does a sum aggregation on decimal values it will store the result in a value with a precision that is the input precision + 10, but with a maximum precision of 38. -For an input precision of 9 and above, Spark will do the aggregations as a java `BigDecimal` +For an input precision of 9 and above, Spark will do the aggregations as a Java `BigDecimal` value which is slow, but guarantees that any overflow can be detected because it can work with effectively unlimited precision. For inputs with a precision of 8 or below Spark will internally do the calculations as a long value, 64-bits. When the precision is 8, you would need at least diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala index fa873dbaac8..3fef0aac5ed 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala @@ -328,7 +328,7 @@ object GpuWindowExec extends Arm { // by our GPU window operations anyways. rep.windowReplacement(spec) case GpuWindowExpression(rep: GpuReplaceWindowFunction, spec) - if rep.shouldReplaceWindow(spec)=> + if rep.shouldReplaceWindow(spec) => rep.windowReplacement(spec) } // Second pass looks for GpuWindowFunctions and GpuWindowSpecDefinitions to build up diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 412e0293bac..fc754ced86a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -381,7 +381,6 @@ case class GpuSampleExec(lowerBound: Double, upperBound: Double, withReplacement override lazy val additionalMetrics: Map[String, GpuMetric] = Map( OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_OP_TIME)) - // TODO CPU vs GPU OP TIME in Debug mode??? 
   override def output: Seq[Attribute] = {
     child.output
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index 4e83885cc06..55953cc1b7c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -579,14 +579,16 @@ case class GpuMax(child: Expression) extends GpuAggregateFunction
  *    this means that we can SUM `174,467,442,481` maximum or minimum decimal values with a
  *    precision of 8 before overflow can no longer be detected. It is much higher for decimal
  *    values with a smaller precision.
- * 2. For decimal values with a precision from 9 to 28 inclusive we sum them as 128-bit values.
+ * 2. For decimal values with a precision from 9 to 20 inclusive we sum them as 128-bit values.
  *    this is very similar to what we do in the first strategy. The main differences are that we
- *    use 128-bit value when doing the sum, and we check for overflow after processing a batch.
+ *    use a 128-bit value when doing the sum, and we check for overflow after processing each batch.
  *    In the case of group-by and reduction that happens after the update stage and also after each
- *    merge stage. In the worst case this mens that we can SUM `24,028,236,692` maximum or minimum
- *    decimal values with a precision of 28 before overflow can no longer be detected.
- * 3. For anything larger than precision 28 we do the same things we do for strategy 2, but we also
- *    take the digits above 28 and sum them separately. We then check to see if they would have
+ *    merge stage. This gives us enough room that we can always detect overflow when summing a
+ *    single batch, even on a merge where we could be doing the aggregation on a batch that has
+ *    all max output values in it.
+ * 3. For values from 21 to 28 inclusive we have enough room to not check for overflow on the update
+ *    aggregation, but for the merge aggregation we need to do some extra checks. This is done by
+ *    taking the digits above 28 and summing them separately. We then check to see if they would have
  *    overflowed the original limits. This lets us detect overflow in cases where the original
  *    value would have wrapped around. The reason this works is because we have a hard limit on the
  *    maximum number of values in a single batch being processed. `Int.MaxValue`, or about 2.2
  *    2.2 billion values and still detect overflow. This equates to a precision of about 10 more
  *    than is needed to hold the higher digits. This effectively gives us unlimited overflow
  *    detection.
+ * 4. For anything larger than precision 28 we do the same overflow detection as for strategy 3,
+ *    but also do it on the update aggregation. This lets us fully detect overflows in any stage
+ *    of an aggregation.
  *
- *
- *
- * but we do want to have some kind of guarantees that are large
- * enough that users feel comfortable using our framework for doing the processing. In Spark there
- * are a few optimizations around SUM where if the output fits in 64-bits, then they will use it
- * to do the SUM instead of processing the values using `Decimal`. This speeds the processing up a
- * lot, but also means that Spark can only detect the overflow if it does not wrap around and go
- * back to being a valid number. The formula for this is.
- * - * `(Long.MaxValue * 2 − DEC_18_MAX) ÷ DEC_8_MAX` = - * - * `DEC_18_MAX` is the maximum value that a `Decimal(18, 0)` can hold. This is because Spark will - * do the SUM as this type, so that it gives us a lot of values before overflow is even possible - * (10-billion) because it adds 10 to the input precision. `DEC_8_MAX` is the maximum value that - * a `Decimal(8, 0)` can hold. This means we can have at least 174 billion values before Spark - * could possibly return a bogus value or over 17 times the number of values it takes to hit the - * overflow case to being with. In practice it will likely be a lot more than that because not all - * values will be the maximum (or minimum) values allowed. - * - * We don't necessarily want to match Spark exactly. We could come close because each batch can have - * at most `Int.MaxValue` rows in it, so if we can ensure we can detect overflow on 2.2 billion - * This means that we want to try and match Spark, and ideally have something around 100-billion+ - * values before overflow is no longer detectable in the worst case. - * - * TODO look at merge vs update for this. Because for update we just need to support Int.MaxValue - * values before we can not detect overflow. For merge we are combining multiple of these results - * together. So if we can support Int.MaxValue of these values, then we know we will never be - * bitten by overflow. So if we can support 2.2 billion values. + * Note that for Window operations either there is no merge stage or it only has a single value + * being merged into a batch instead of an entire batch being merged together. This lets us handle + * the overflow detection with what is built into GpuAdd. */ object GpuDecimalSumOverflow { /** From a0ef068d54b164cee91aa06439406a5ec9487a98 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 3 Dec 2021 14:27:10 -0600 Subject: [PATCH 3/5] Review comments --- .../src/main/scala/com/nvidia/spark/rapids/aggregate.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index 7ccb95ee267..58896717179 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -691,8 +691,7 @@ class GpuHashAggregateIterator( /** * Apply the "pre" step: preMerge for merge, or pass-through in the update case * @param toAggregateBatch - input (to the agg) batch from the child directly in the - * merge - * case, or from the `inputProjection` in the update case. + * merge case, or from the `inputProjection` in the update case. 
* @return a pre-processed batch that can be later cuDF aggregated */ def preProcess(toAggregateBatch: ColumnarBatch): ColumnarBatch = { From bf020c7dac3f4110b48f830c2cdb7b5054c0ce91 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 6 Dec 2021 10:05:49 -0600 Subject: [PATCH 4/5] Fix memory issue with overflow test --- integration_tests/src/main/python/hash_aggregate_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index d35035ae664..a47dcff2fd6 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -310,7 +310,12 @@ def test_hash_reduction_decimal_overflow_sum(precision): assert_gpu_and_cpu_are_equal_collect( lambda spark: spark.range(count)\ .selectExpr("CAST('{}' as Decimal({}, 0)) as a".format(constant, precision))\ - .selectExpr("SUM(a)")) + .selectExpr("SUM(a)"), + # This is set to 128m becuase of a number of other bugs that compond to having us + # run out of memory in some setups. These should not happen in production, becasue + # we really are just doing a really bad job at multiplying to get this result so + # some optimizations are conspiring against us. + conf = {'spark.rapids.sql.batchSizeBytes': '128m'}) @pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn) def test_hash_grpby_sum_count_action(data_gen): From e68b833068f9955206b2aac404c34fab94ae03fe Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 6 Dec 2021 10:50:00 -0600 Subject: [PATCH 5/5] Update integration_tests/src/main/python/hash_aggregate_test.py Co-authored-by: Jason Lowe --- integration_tests/src/main/python/hash_aggregate_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index a47dcff2fd6..406ec0fb694 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -311,8 +311,8 @@ def test_hash_reduction_decimal_overflow_sum(precision): lambda spark: spark.range(count)\ .selectExpr("CAST('{}' as Decimal({}, 0)) as a".format(constant, precision))\ .selectExpr("SUM(a)"), - # This is set to 128m becuase of a number of other bugs that compond to having us - # run out of memory in some setups. These should not happen in production, becasue + # This is set to 128m because of a number of other bugs that compound to having us + # run out of memory in some setups. These should not happen in production, because # we really are just doing a really bad job at multiplying to get this result so # some optimizations are conspiring against us. conf = {'spark.rapids.sql.batchSizeBytes': '128m'})
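
Editor's addition: a self-contained Scala sketch (not from the patch series; the
names and the 10^9 cutoff are invented for illustration, with 64-bit longs standing
in for the 128-bit decimal storage) of why summing the high digits separately
detects a sum that has silently wrapped:

```scala
object SumOverflowSketch {
  // Stand-in for 10^updateCutoffPrecision, scaled down to fit long arithmetic.
  val cutoff = 1000000000L

  // Sum `values`, detecting wrap-around past `maxAllowed` by also summing the
  // digits above `cutoff` and bounds-checking that (much smaller) side sum.
  def sumWithCheck(values: Seq[Long], maxAllowed: Long): Option[Long] = {
    val wrapped = values.foldLeft(0L)(_ + _) // long addition wraps silently
    val highDigits = values.map(_ / cutoff).sum // small values, lots of headroom
    if (math.abs(highDigits) > maxAllowed / cutoff) {
      None // the wrapped sum cannot be trusted, even if it looks in range
    } else {
      Some(wrapped)
    }
  }

  def main(args: Array[String]): Unit = {
    val vals = Seq.fill(4)(Long.MaxValue / 2)
    println(vals.sum)                          // -4: wrapped, looks plausible
    println(sumWithCheck(vals, Long.MaxValue)) // None: overflow detected
  }
}
```

The real checks are more careful about sizing (the side sum is kept in DECIMAL128
and its allowed range is derived from the result precision), but the shape of the
detection is the same.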