Skip to content

Commit

Permalink
Fixes creation of invalid DecimalType in GpuDivide.tagExprForGpu (NVI…
Browse files Browse the repository at this point in the history
…DIA#1991)

* fixes NVIDIA#1984

Signed-off-by: Raza Jafri <[email protected]>

* added unit tests

Signed-off-by: Raza Jafri <[email protected]>

* make sure the exception is not AnalysisException

Signed-off-by: Raza Jafri <[email protected]>

Co-authored-by: Raza Jafri <[email protected]>
  • Loading branch information
razajafri and razajafri authored Mar 23, 2021
1 parent 48bb24d commit 244b5ca
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
6 changes: 5 additions & 1 deletion integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_311
import pyspark.sql.functions as f
from pyspark.sql.utils import IllegalArgumentException

decimal_gens_not_max_prec = [decimal_gen_neg_scale, decimal_gen_scale_precision,
decimal_gen_same_scale_precision, decimal_gen_64bit]
Expand Down Expand Up @@ -69,7 +70,10 @@ def test_multiplication_mixed(lhs, rhs):
f.col('a') * f.col('b')),
conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', [double_gen, decimal_gen_neg_scale, DecimalGen(6, 3), DecimalGen(5, 5), DecimalGen(6, 0)], ids=idfn)
@pytest.mark.parametrize('data_gen', [double_gen, decimal_gen_neg_scale, DecimalGen(6, 3),
DecimalGen(5, 5), DecimalGen(6, 0),
pytest.param(DecimalGen(38, 21), marks=pytest.mark.xfail(reason="The precision is too large to be supported on the GPU", raises=IllegalArgumentException)),
pytest.param(DecimalGen(21, 17), marks=pytest.mark.xfail(reason="The precision is too large to be supported on the GPU", raises=IllegalArgumentException))], ids=idfn)
def test_division(data_gen):
data_type = data_gen.data_type
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1714,29 +1714,39 @@ object GpuOverrides {
(childExprs.head.dataType, childExprs(1).dataType) match {
case (l: DecimalType, r: DecimalType) =>
val outputType = GpuDivideUtil.decimalDataType(l, r)
// We will never hit a case where outputType.precision < outputType.scale + r.scale.
// So there is no need to protect against that.
// The only two cases in which there is a possibility of the intermediary scale
// exceeding the intermediary precision is when l.precision < l.scale or l
// .precision < 0, both of which aren't possible.
// Proof:
// case 1:
// outputType.precision = p1 - s1 + s2 + s1 + p2 + 1 + 1
// outputType.scale = p1 + s2 + p2 + 1 + 1
// To find out if outputType.precision < outputType.scale simplifies to p1 < s1,
// which is never possible
//
// case 2:
// outputType.precision = p1 - s1 + s2 + 6 + 1
// outputType.scale = 6 + 1
// To find out if outputType.precision < outputType.scale simplifies to p1 < 0
// which is never possible
// Case 1: OutputType.precision doesn't get truncated
// We will never hit a case where outputType.precision < outputType.scale + r.scale.
// So there is no need to protect against that.
// The only two cases in which there is a possibility of the intermediary scale
// exceeding the intermediary precision is when l.precision < l.scale or l
// .precision < 0, both of which aren't possible.
// Proof:
// case 1:
// outputType.precision = p1 - s1 + s2 + s1 + p2 + 1 + 1
// outputType.scale = p1 + s2 + p2 + 1 + 1
// To find out if outputType.precision < outputType.scale simplifies to p1 < s1,
// which is never possible
//
// case 2:
// outputType.precision = p1 - s1 + s2 + 6 + 1
// outputType.scale = 6 + 1
// To find out if outputType.precision < outputType.scale simplifies to p1 < 0
// which is never possible
// Case 2: OutputType.precision gets truncated to 38
// In this case we have to make sure the r.precision + l.scale + r.scale + 1 <= 38
// Otherwise the intermediate result will overflow
// TODO We should revisit the proof one more time after we support 128-bit decimals
val intermediateResult = DecimalType(outputType.precision, outputType.scale + r.scale)
if (intermediateResult.precision > DType.DECIMAL64_MAX_PRECISION) {
willNotWorkOnGpu("The actual output precision of the divide is too large" +
if (l.precision + l.scale + r.scale + 1 > 38) {
willNotWorkOnGpu("The intermediate output precision of the divide is too " +
s"large to be supported on the GPU i.e. Decimal(${outputType.precision}, " +
s"${outputType.scale + r.scale})")
} else {
val intermediateResult =
DecimalType(outputType.precision, outputType.scale + r.scale)
if (intermediateResult.precision > DType.DECIMAL64_MAX_PRECISION) {
willNotWorkOnGpu("The actual output precision of the divide is too large" +
s" to fit on the GPU $intermediateResult")
}
}
case _ => // NOOP
}
Expand Down

0 comments on commit 244b5ca

Please sign in to comment.