diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index a1f1c244d33..f335fc0016a 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1076,6 +1076,14 @@ def test_day_time_interval_division_number_no_overflow2(data_gen):
         # avoid dividing by 0
         lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))

+def _get_overflow_df_1col(spark, data_type, value, expr):
+    return spark.createDataFrame(
+        SparkContext.getOrCreate().parallelize([value]),
+        StructType([
+            StructField('a', data_type)
+        ])
+    ).selectExpr(expr)
+
 def _get_overflow_df_2cols(spark, data_types, values, expr):
     return spark.createDataFrame(
         SparkContext.getOrCreate().parallelize([values]),
@@ -1086,6 +1094,30 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
     ).selectExpr(expr)

 # test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value_pair', [
+    (LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
+    (IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
+], ids=idfn)
+def test_day_time_interval_division_overflow(data_type, value_pair):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
+        conf={},
+        error_message='SparkArithmeticException: Overflow in integral divide.')
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value_pair', [
+    (FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
+    (DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
+    (FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
+    (DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
+], ids=idfn)
+def test_day_time_interval_division_round_overflow(data_type, value_pair):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
+        conf={},
+        error_message='java.lang.ArithmeticException')
+
 @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value_pair', [
     (ByteType(), [timedelta(seconds=1), 0]),
@@ -1096,19 +1128,47 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
     (FloatType(), [timedelta(seconds=1), -0.0]),
     (DoubleType(), [timedelta(seconds=1), 0.0]),
     (DoubleType(), [timedelta(seconds=1), -0.0]),
-    (FloatType(), [timedelta(seconds=1), float('NaN')]),
-    (DoubleType(), [timedelta(seconds=1), float('NaN')]),
-    (FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
-    (DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
-    (FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
-    (DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
-    (LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
     (FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
     (DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
-    (IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
 ], ids=idfn)
-def test_day_time_interval_division_overflow(data_type, value_pair):
+def test_day_time_interval_divided_by_zero(data_type, value_pair):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
+        conf={},
+        error_message='SparkArithmeticException: Division by zero.')
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('zero_literal', ['0', '0.0f', '-0.0f'], ids=idfn)
+def test_day_time_interval_divided_by_zero_scalar(zero_literal):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [timedelta(seconds=1)], 'a / ' + zero_literal).collect(),
+        conf={},
+        error_message='SparkArithmeticException: Division by zero.')
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value', [
+    (ByteType(), 0),
+    (ShortType(), 0),
+    (IntegerType(), 0),
+    (LongType(), 0),
+    (FloatType(), 0.0),
+    (FloatType(), -0.0),
+    (DoubleType(), 0.0),
+    (DoubleType(), -0.0),
+], ids=idfn)
+def test_day_time_interval_scalar_divided_by_zero(data_type, value):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_1col(spark, data_type, [value], 'INTERVAL 1 SECOND / a').collect(),
+        conf={},
+        error_message='SparkArithmeticException: Division by zero.')
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value_pair', [
+    (FloatType(), [timedelta(seconds=1), float('NaN')]),
+    (DoubleType(), [timedelta(seconds=1), float('NaN')]),
+], ids=idfn)
+def test_day_time_interval_division_nan(data_type, value_pair):
     assert_gpu_and_cpu_error(
         df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
         conf={},
-        error_message='ArithmeticException')
+        error_message='java.lang.ArithmeticException')
diff --git a/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index b5f4cf00981..c598740cf9f 100644
--- a/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -60,4 +60,8 @@ object RapidsErrorUtils {
     new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
       s"Decimal(${toType.precision}, ${toType.scale}).")
   }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    new ArithmeticException("Overflow in integral divide.")
+  }
 }
diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index b5f4cf00981..c598740cf9f 100644
--- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -60,4 +60,8 @@ object RapidsErrorUtils {
     new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
       s"Decimal(${toType.precision}, ${toType.scale}).")
   }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    new ArithmeticException("Overflow in integral divide.")
+  }
 }
diff --git a/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 1dd6219a0ff..0856945ed8f 100644
--- a/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -62,5 +62,9 @@ object RapidsErrorUtils {
       value, toType.precision, toType.scale
     )
   }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    QueryExecutionErrors.overflowInIntegralDivideError()
+  }
 }
diff --git a/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index ab15d817065..fc48fb72d3f 100644
--- a/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -65,4 +65,8 @@ object RapidsErrorUtils {
       value, toType.precision, toType.scale
     )
   }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    QueryExecutionErrors.overflowInIntegralDivideError()
+  }
 }
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index a6e9ecd235b..a43ba0c8a50 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -64,4 +64,10 @@ object RapidsErrorUtils {
       value, toType.precision, toType.scale, context
     )
   }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    QueryExecutionErrors.arithmeticOverflowError(
+      "Overflow in integral divide", "try_divide", context
+    )
+  }
 }
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
index ab601bbb7e5..785232f4f34 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -22,6 +22,7 @@ import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, RoundMod
 import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}

 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant}
+import org.apache.spark.sql.rapids.GpuDivModLike.makeZeroScalar
 import org.apache.spark.sql.types._

 object IntervalUtils extends Arm {
@@ -34,7 +35,7 @@ object IntervalUtils extends Arm {
     withResource(longCv.castTo(DType.INT32)) { intResult =>
       withResource(longCv.notEqualTo(intResult)) { notEquals =>
         if (BoolUtils.isAnyValidTrue(notEquals)) {
-          throw new ArithmeticException("overflow occurs")
+          throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
         } else {
           intResult.incRefCount()
         }
@@ -47,7 +48,7 @@ object IntervalUtils extends Arm {
     withResource(Scalar.fromLong(minValue)) { minScalar =>
       withResource(decimal128Cv.lessThan(minScalar)) { lessThanMin =>
         if (BoolUtils.isAnyValidTrue(lessThanMin)) {
-          throw new ArithmeticException("overflow occurs")
+          throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
         }
       }
     }
@@ -56,7 +57,7 @@ object IntervalUtils extends Arm {
     withResource(Scalar.fromLong(maxValue)) { maxScalar =>
       withResource(decimal128Cv.greaterThan(maxScalar)) { greaterThanMax =>
         if (BoolUtils.isAnyValidTrue(greaterThanMax)) {
-          throw new ArithmeticException("overflow occurs")
+          throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
         }
       }
     }
@@ -260,7 +261,7 @@ object IntervalUtils extends Arm {
         withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
           withResource(isMin.and(isNegOne)) { invalid =>
             if (BoolUtils.isAnyValidTrue(invalid)) {
-              throw new ArithmeticException("overflow occurs")
+              throw RapidsErrorUtils.overflowInIntegralDivideError()
             }
           }
         }
@@ -268,13 +269,13 @@ object IntervalUtils extends Arm {
       case (lCv: ColumnVector, rS: Scalar) =>
         withResource(lCv.equalTo(minScalar)) { isMin =>
           if (getLong(rS) == -1L && BoolUtils.isAnyValidTrue(isMin)) {
-            throw new ArithmeticException("overflow occurs")
+            throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
           }
         }
       case (lS: Scalar, rCv: ColumnVector) =>
         withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
           if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) {
-            throw new ArithmeticException("overflow occurs")
+            throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
           }
         }
       case (lS: Scalar, rS: Scalar) =>
@@ -519,14 +520,29 @@ case class GpuDivideDTInterval(
   override def right: Expression = num

   override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
+    withResource(makeZeroScalar(numScalar.getBase.getType)) { zeroScalar =>
+      if (numScalar.getBase.equals(zeroScalar)) {
+        throw RapidsErrorUtils.divByZeroError(origin)
+      }
+    }
     doColumnar(interval.getBase, numScalar.getBase, num.dataType)
   }

   override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
+    withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
+      if (num.getBase.contains(zeroScalar)) {
+        throw RapidsErrorUtils.divByZeroError(origin)
+      }
+    }
     doColumnar(interval.getBase, num.getBase, num.dataType)
   }

   override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
+    withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
+      if (num.getBase.contains(zeroScalar)) {
+        throw RapidsErrorUtils.divByZeroError(origin)
+      }
+    }
     doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
   }