Full support for SUM overflow detection on decimal [databricks] #4272

Merged
merged 6 commits into from
Dec 6, 2021
Changes from 2 commits
83 changes: 25 additions & 58 deletions docs/compatibility.md
Expand Up @@ -109,63 +109,31 @@ a few operations that we cannot support to the same degree as Spark can on the C

### Decimal Sum Aggregation

When Apache Spark does a sum aggregation on decimal values it will store the result in a value
with a precision that is the input precision + 10, but with a maximum precision of 38. The table
below shows the number of rows/values in an aggregation before an overflow is possible,
and the number of rows/values in the aggregation before an overflow might not be detected.
The numbers are for Spark 3.1.0 and above, after a number of fixes were put in place; please see
A number of fixes for overflow detection went into Spark 3.1.0. Please see
[SPARK-28067](https://issues.apache.org/jira/browse/SPARK-28067) and
[SPARK-32018](https://issues.apache.org/jira/browse/SPARK-32018) for more information.
Please also note that these are for the worst case situations, meaning all the values in the sum
were either the largest or smallest values possible to be stored in the input type. In the common
case, where the numbers are smaller, or vary between positive and negative values, many more
rows/values can be processed without any issues.

|Input Precision|Number of values before overflow is possible|Maximum number of values for guaranteed overflow detection (Spark CPU)|Maximum number of values for guaranteed overflow detection (RAPIDS GPU)|
|---------------|------------------------------|------------|-------------|
|1 |11,111,111,111 |2,049,638,219,301,061,290 |Same as CPU |
|2 |10,101,010,101 |186,330,738,118,278,299 |Same as CPU |
|3 |10,010,010,010 |18,465,199,272,982,534 |Same as CPU |
|4 |10,001,000,100 |1,844,848,892,260,181 |Same as CPU |
|5 |10,000,100,001 |184,459,285,329,948 |Same as CPU |
|6 |10,000,010,000 |18,436,762,510,472 |Same as CPU |
|7 |10,000,001,000 |1,834,674,590,838 |Same as CPU |
|8 |10,000,000,100 |174,467,442,481 |Same as CPU |
|9 |10,000,000,010 |Unlimited |Unlimited |
|10 - 19 |10,000,000,000 |Unlimited |Unlimited |
|20 |10,000,000,000 |Unlimited |3,402,823,659,209,384,634 |
|21 |10,000,000,000 |Unlimited |340,282,356,920,938,463 |
|22 |10,000,000,000 |Unlimited |34,028,226,692,093,846 |
|23 |10,000,000,000 |Unlimited |3,402,813,669,209,384 |
|24 |10,000,000,000 |Unlimited |340,272,366,920,938 |
|25 |10,000,000,000 |Unlimited |34,018,236,692,093 |
|26 |10,000,000,000 |Unlimited |3,392,823,669,209 |
|27 |10,000,000,000 |Unlimited |330,282,366,920 |
|28 |10,000,000,000 |Unlimited |24,028,236,692 |
|29 |1,000,000,000 |Unlimited |Falls back to CPU |
|30 |100,000,000 |Unlimited |Falls back to CPU |
|31 |10,000,000 |Unlimited |Falls back to CPU |
|32 |1,000,000 |Unlimited |Falls back to CPU |
|33 |100,000 |Unlimited |Falls back to CPU |
|34 |10,000 |Unlimited |Falls back to CPU |
|35 |1,000 |Unlimited |Falls back to CPU |
|36 |100 |Unlimited |Falls back to CPU |
|37 |10 |Unlimited |Falls back to CPU |
|38 |1 |Unlimited |Falls back to CPU |

For an input precision of 9 and above, Spark will do the aggregations as a `BigDecimal`
value which is slow, but guarantees that any overflow can be detected. For inputs with a
precision of 8 or below Spark will internally do the calculations as a long value, 64-bits.
When the precision is 8, you would need at least 174-billion values/rows contributing to a
single aggregation result, and even then all the values would need to be either the largest
or the smallest value possible to be stored in the type before the overflow is no longer detected.

For the RAPIDS Accelerator we only have access to at most a 128-bit value to store the results
in and still detect overflow. Because of this we cannot guarantee overflow detection in all
cases. In some cases we can guarantee unlimited overflow detection because of the maximum number of
values that RAPIDS will aggregate in a single batch. But even in the worst case, for a decimal value
with a precision of 28, the user would still have to aggregate so many values that the sum overflows
2.4 times over before we are no longer able to detect it.
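
The "2.4 times over" figure can be sanity-checked with a little arithmetic. The sketch below is a
back-of-the-envelope reproduction of the worst-case limit for precision 28 in the table above; it is
not the plugin's actual overflow-detection code:

```python
# With a 128-bit accumulator, a wrapped sum can only land back inside the valid
# DECIMAL(38) range after the true sum has moved past roughly 2^128 - 10^38.
max_input = 10 ** 28 - 1                # worst-case value for an input precision of 28
rows = (2 ** 128 - 10 ** 38) // max_input
print(rows)                             # ~24,028,236,692, matching the table row for precision 28
print(rows * max_input / 10 ** 38)      # ~2.4x the DECIMAL(38) maximum
```
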
[SPARK-32018](https://issues.apache.org/jira/browse/SPARK-32018) for more detailed information.
We were able to back port some of these fixes, but others require Spark 3.1.0 or above to fully
detect overflow in all cases. As such, on versions of Spark older than 3.1.0, there is a possibility
of data corruption in some corner cases involving large decimal values.
This is true for both the CPU and GPU implementations, but there are fewer of these cases for the
GPU. If this concerns you, you should upgrade to Spark 3.1.0 or above.

When Apache Spark does a sum aggregation on decimal values it will store the result in a value
with a precision that is the input precision + 10, but with a maximum precision of 38.
For an input precision of 9 and above, Spark will do the aggregations as a Java `BigDecimal`
value which is slow, but guarantees that any overflow can be detected because it can work with
effectively unlimited precision. For inputs with a precision of 8 or below, Spark will internally do
the calculations as a 64-bit long value. When the precision is 8, you would need at least
174,467,442,482 values/rows contributing to a single aggregation result before the overflow is no
longer detected. Even then all the values would need to be either the largest or the smallest value
possible to be stored in the type for the overflow to cause data corruption.
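
For example, the widened result type can be observed directly from PySpark. The snippet below is a
minimal sketch that assumes a local Spark session is available; it is not part of this PR's tests:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# SUM over DECIMAL(10, 2) is stored as DECIMAL(10 + 10, 2)
df = spark.range(3).selectExpr("CAST(id AS DECIMAL(10, 2)) AS d")
print(df.selectExpr("SUM(d) AS s").schema["s"].dataType)   # DecimalType(20,2)

# SUM over DECIMAL(30, 2) would need precision 40, so it is capped at the maximum of 38
df = spark.range(3).selectExpr("CAST(id AS DECIMAL(30, 2)) AS d")
print(df.selectExpr("SUM(d) AS s").schema["s"].dataType)   # DecimalType(38,2)
```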

For the RAPIDS Accelerator we don't have direct access to unlimited precision for our calculations
like the CPU does. For input values with a precision of 8 and below we follow Spark and process the
data the same way, as a 64-bit value. For larger inputs we will do extra calculations looking at the
higher-order digits to be able to detect overflow in all cases. Because of this, you may see
some performance differences depending on the input precision used. The differences will show up
when going from an input precision of 8 to 9, and again when going from 28 to 29.
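
As a rough summary of where those steps come from, the tiers below are an interpretation of this
paragraph (a hypothetical helper, not the plugin's API or exact internals):

```python
def gpu_decimal_sum_strategy(input_precision: int) -> str:
    """Approximate accumulator strategy by input precision (interpretation only)."""
    if input_precision <= 8:
        # result precision is at most 18, so a 64-bit long is enough, same as the Spark CPU path
        return "64-bit long accumulator"
    elif input_precision <= 28:
        # result precision (input + 10) still fits within a 128-bit decimal
        return "128-bit decimal accumulator"
    else:
        # result precision is capped at 38; extra checks on the higher-order digits are needed
        return "128-bit decimal accumulator plus extra overflow checks"

for p in (8, 9, 28, 29):
    print(p, "->", gpu_decimal_sum_strategy(p))
```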

### Decimal Average

Expand All @@ -175,8 +143,7 @@ have. It also inherits some issues from Spark itself. See
https://issues.apache.org/jira/browse/SPARK-37024 for a detailed description of some issues
with average in Spark.

In order to be able to guarantee overflow detection on the sum with at least 100-billion values
and to be able to guarantee doing the divide with half up rounding at the end we only support
In order to be able to guarantee doing the divide with half-up rounding at the end, we only support
average on input values with a precision of 23 or below. This is 38 - 10 for the sum guarantees,
and then 5 less to be able to shift the left-hand side of the divide enough to get a correct
answer that can be rounded to the result that Spark would produce.
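
Spelled out, the limit is just the arithmetic from the paragraph above (a sketch, not plugin code):

```python
MAX_DECIMAL_PRECISION = 38
SUM_HEADROOM = 10   # Spark widens a decimal SUM result by 10 digits of precision
DIVIDE_SHIFT = 5    # digits needed to shift the dividend for half-up rounding

print(MAX_DECIMAL_PRECISION - SUM_HEADROOM - DIVIDE_SHIFT)  # 23, the maximum supported input precision
```
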
Expand Down
8 changes: 4 additions & 4 deletions docs/supported_ops.md
Expand Up @@ -386,13 +386,13 @@ Accelerator supports are described below.
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td><em>PS<br/>max DECIMAL precision of 18</em></td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>max child DECIMAL precision of 18;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
82 changes: 65 additions & 17 deletions integration_tests/src/main/python/hash_aggregate_test.py
Expand Up @@ -265,33 +265,59 @@ def get_params(init_list, marked_params=[]):
('b', decimal_gen_20_2),
('c', decimal_gen_20_2)]

# NOTE: on older versions of Spark, decimal 38 causes the CPU to crash
# instead of detecting overflows. We have versions of this for both
# 36 and 38 so we can get some coverage for old versions and full
# coverage for newer versions.
_grpkey_short_very_big_decimals = [
('a', RepeatSeqGen(short_gen, length=50)),
('b', decimal_gen_36_5),
('c', decimal_gen_36_5)]

_grpkey_short_very_big_neg_scale_decimals = [
('a', RepeatSeqGen(short_gen, length=50)),
('b', decimal_gen_36_neg5),
('c', decimal_gen_36_neg5)]

_grpkey_short_full_decimals = [
('a', RepeatSeqGen(short_gen, length=50)),
('b', decimal_gen_38_0),
('c', decimal_gen_38_0)]

_grpkey_short_full_neg_scale_decimals = [
('a', RepeatSeqGen(short_gen, length=50)),
('b', decimal_gen_38_neg10),
('c', decimal_gen_38_neg10)]


_init_list_no_nans_with_decimal = _init_list_no_nans + [
_grpkey_small_decimals]

_init_list_no_nans_with_decimalbig = _init_list_no_nans + [
_grpkey_small_decimals, _grpkey_short_mid_decimals, _grpkey_short_big_decimals]
_grpkey_small_decimals, _grpkey_short_mid_decimals,
_grpkey_short_big_decimals, _grpkey_short_very_big_decimals,
_grpkey_short_very_big_neg_scale_decimals]


_init_list_full_decimal = [_grpkey_short_full_decimals,
_grpkey_short_full_neg_scale_decimals]

# TODO when we can support sum on larger types (https://github.com/NVIDIA/spark-rapids/issues/3944)
# we should move to a larger type and use a smaller count so we can avoid the long CPU run.
# Then we should look at splitting the reduction up so we do half on the CPU and half on the GPU,
# so we can test compatibility too (but without spending even longer computing).
def test_hash_reduction_decimal_overflow_sum():
# Any smaller precision takes way too long to process on the CPU
@pytest.mark.parametrize('precision', [38, 37, 36, 35, 34, 33, 32, 31, 30], ids=idfn)
def test_hash_reduction_decimal_overflow_sum(precision):
constant = '9' * precision
count = pow(10, 38 - precision)
assert_gpu_and_cpu_are_equal_collect(
# Spark adds +10 to precision so we need 10-billion entries before overflow is even possible.
# we use 10-billion and 2 to make sure we hit the overflow.
lambda spark: spark.range(0, 10000000002, 1, 48)\
.selectExpr("CAST('9999999999' as Decimal(10, 0)) as a")\
.selectExpr("SUM(a)"),
# set the batch size small because we can have limited GPU memory and the first select
# doubles the size of the batch
conf = {'spark.rapids.sql.batchSizeBytes': '64m'})
lambda spark: spark.range(count)\
.selectExpr("CAST('{}' as Decimal({}, 0)) as a".format(constant, precision))\
.selectExpr("SUM(a)"))

@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
def test_hash_grpby_sum_count_action(data_gen):
assert_gpu_and_cpu_row_counts_equal(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b'))
)

@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
def test_hash_reduction_sum_count_action(data_gen):
assert_gpu_and_cpu_row_counts_equal(
Expand All @@ -307,19 +333,41 @@ def test_hash_reduction_sum_count_action(data_gen):
def test_hash_grpby_sum(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')),
conf=conf
)
conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf))

@pytest.mark.skipif(is_before_spark_311(), reason="SUM overflows for CPU were fixed in Spark 3.1.1")
@shuffle_test
@approximate_float
@ignore_order
@incompat
@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', _init_list_full_decimal, ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
def test_hash_grpby_sum_full_decimal(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b')),
conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf))

@approximate_float
@ignore_order
@incompat
@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens + [decimal_gen_20_2, decimal_gen_30_2, decimal_gen_36_5, decimal_gen_36_neg5], ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs_with_nans, params_markers_for_confs_nans), ids=idfn)
def test_hash_reduction_sum(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen, length=100).selectExpr("SUM(a)"),
conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf))

@pytest.mark.skipif(is_before_spark_311(), reason="SUM overflows for CPU were fixed in Spark 3.1.1")
@approximate_float
@ignore_order
@incompat
@pytest.mark.parametrize('data_gen', numeric_gens + decimal_gens + [decimal_gen_38_0, decimal_gen_38_10, decimal_gen_38_neg10], ids=idfn)
@pytest.mark.parametrize('conf', get_params(_confs_with_nans, params_markers_for_confs_nans), ids=idfn)
def test_hash_reduction_sum_full_decimal(data_gen, conf):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen, length=100).selectExpr("SUM(a)"),
conf = copy_and_update(allow_negative_scale_of_decimal_conf, conf))

@approximate_float
@ignore_order
@incompat
Expand Down
8 changes: 4 additions & 4 deletions integration_tests/src/main/python/window_function_test.py
Expand Up @@ -164,7 +164,7 @@ def test_decimal128_count_window_no_part(data_gen):
conf = allow_negative_scale_of_decimal_conf)

@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn)
def test_decimal_sum_window(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark: three_col_df(spark, byte_gen, LongRangeGen(), data_gen),
Expand All @@ -177,7 +177,7 @@ def test_decimal_sum_window(data_gen):
conf = allow_negative_scale_of_decimal_conf)

@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn)
def test_decimal_sum_window_no_part(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark: two_col_df(spark, LongRangeGen(), data_gen),
Expand All @@ -191,7 +191,7 @@ def test_decimal_sum_window_no_part(data_gen):


@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn)
def test_decimal_running_sum_window(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark: three_col_df(spark, byte_gen, LongRangeGen(), data_gen),
Expand All @@ -205,7 +205,7 @@ def test_decimal_running_sum_window(data_gen):
{'spark.rapids.sql.batchSizeBytes': '100'}))

@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', decimal_gens + decimal_128_gens, ids=idfn)
def test_decimal_running_sum_window_no_part(data_gen):
assert_gpu_and_cpu_are_equal_sql(
lambda spark: two_col_df(spark, LongRangeGen(), data_gen),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@

package org.apache.spark.sql.rapids.aggregate

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuSumBase
import org.apache.spark.sql.types.DataType

case class GpuSum(child: Expression,
resultType: DataType,
failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled) extends GpuSumBase {
override val extraDecimalOverflowChecks: Boolean = false
object GpuSumDefaults {
val hasIsEmptyField: Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@

package org.apache.spark.sql.rapids.aggregate

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuSumBase
import org.apache.spark.sql.types.DataType

case class GpuSum(child: Expression,
resultType: DataType,
failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled) extends GpuSumBase {
override val extraDecimalOverflowChecks: Boolean = true
}
object GpuSumDefaults {
val hasIsEmptyField: Boolean = true
}
30 changes: 18 additions & 12 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Expand Up @@ -1345,24 +1345,30 @@ object GpuCast extends Arm {
}
}

def fixDecimalBounds(input: ColumnView,
outOfBounds: ColumnView,
ansiMode: Boolean): ColumnVector = {
if (ansiMode) {
withResource(outOfBounds.any()) { isAny =>
if (isAny.isValid && isAny.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
input.copyToColumnVector()
} else {
withResource(Scalar.fromNull(input.getType)) { nullVal =>
outOfBounds.ifElse(nullVal, input)
}
}
}

def checkNFixDecimalBounds(
input: ColumnView,
to: DecimalType,
ansiMode: Boolean): ColumnVector = {
assert(input.getType.isDecimalType)
withResource(DecimalUtil.outOfBounds(input, to)) { outOfBounds =>
if (ansiMode) {
withResource(outOfBounds.any()) { isAny =>
if (isAny.isValid && isAny.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
input.copyToColumnVector()
} else {
withResource(Scalar.fromNull(input.getType)) { nullVal =>
outOfBounds.ifElse(nullVal, input)
}
}
fixDecimalBounds(input, outOfBounds, ansiMode)
}
}

Expand Down