diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index 3951e822776..a44274af01b 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -20,6 +20,7 @@ from pyspark.sql.types import * from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_311 import pyspark.sql.functions as f +from pyspark.sql.utils import IllegalArgumentException decimal_gens_not_max_prec = [decimal_gen_neg_scale, decimal_gen_scale_precision, decimal_gen_same_scale_precision, decimal_gen_64bit] @@ -69,7 +70,10 @@ def test_multiplication_mixed(lhs, rhs): f.col('a') * f.col('b')), conf=allow_negative_scale_of_decimal_conf) -@pytest.mark.parametrize('data_gen', [double_gen, decimal_gen_neg_scale, DecimalGen(6, 3), DecimalGen(5, 5), DecimalGen(6, 0)], ids=idfn) +@pytest.mark.parametrize('data_gen', [double_gen, decimal_gen_neg_scale, DecimalGen(6, 3), + DecimalGen(5, 5), DecimalGen(6, 0), +pytest.param(DecimalGen(38, 21), marks=pytest.mark.xfail(reason="The precision is too large to be supported on the GPU", raises=IllegalArgumentException)), +pytest.param(DecimalGen(21, 17), marks=pytest.mark.xfail(reason="The precision is too large to be supported on the GPU", raises=IllegalArgumentException))], ids=idfn) def test_division(data_gen): data_type = data_gen.data_type assert_gpu_and_cpu_are_equal_collect( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 5cad67b49fc..1d32fe53b33 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1714,29 +1714,39 @@ object GpuOverrides { (childExprs.head.dataType, childExprs(1).dataType) match { case (l: DecimalType, r: DecimalType) => val outputType = 
GpuDivideUtil.decimalDataType(l, r) - // We will never hit a case where outputType.precision < outputType.scale + r.scale. - // So there is no need to protect against that. - // The only two cases in which there is a possibility of the intermediary scale - // exceeding the intermediary precision is when l.precision < l.scale or l - // .precision < 0, both of which aren't possible. - // Proof: - // case 1: - // outputType.precision = p1 - s1 + s2 + s1 + p2 + 1 + 1 - // outputType.scale = p1 + s2 + p2 + 1 + 1 - // To find out if outputType.precision < outputType.scale simplifies to p1 < s1, - // which is never possible - // - // case 2: - // outputType.precision = p1 - s1 + s2 + 6 + 1 - // outputType.scale = 6 + 1 - // To find out if outputType.precision < outputType.scale simplifies to p1 < 0 - // which is never possible + // Case 1: outputType.precision doesn't get truncated + // We will never hit a case where outputType.precision < outputType.scale + r.scale. + // So there is no need to protect against that. + // The only two cases in which there is a possibility of the intermediary scale + // exceeding the intermediary precision are when l.precision < l.scale or + // l.precision < 0, both of which aren't possible. 
+ // Proof: + // case 1: + // outputType.precision = p1 - s1 + s2 + s1 + p2 + 1 + 1 + // outputType.scale = p1 + s2 + p2 + 1 + 1 + // To find out if outputType.precision < outputType.scale simplifies to p1 < s1, + // which is never possible // + // case 2: + // outputType.precision = p1 - s1 + s2 + 6 + 1 + // outputType.scale = 6 + 1 + // To find out if outputType.precision < outputType.scale simplifies to p1 < 0 + // which is never possible + // Case 2: outputType.precision gets truncated to 38 + // In this case we have to make sure the r.precision + l.scale + r.scale + 1 <= 38 + // Otherwise the intermediate result will overflow // TODO We should revisit the proof one more time after we support 128-bit decimals - val intermediateResult = DecimalType(outputType.precision, outputType.scale + r.scale) - if (intermediateResult.precision > DType.DECIMAL64_MAX_PRECISION) { - willNotWorkOnGpu("The actual output precision of the divide is too large" + + if (l.precision + l.scale + r.scale + 1 > 38) { + willNotWorkOnGpu("The intermediate output precision of the divide is too " + + s"large to be supported on the GPU i.e. Decimal(${outputType.precision}, " + + s"${outputType.scale + r.scale})") + } else { + val intermediateResult = + DecimalType(outputType.precision, outputType.scale + r.scale) + if (intermediateResult.precision > DType.DECIMAL64_MAX_PRECISION) { + willNotWorkOnGpu("The actual output precision of the divide is too large" + s" to fit on the GPU $intermediateResult") + } } case _ => // NOOP }