From ac564595db41c5be0177e8305d677ae6abed8aae Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Tue, 19 Jul 2022 12:23:07 +0800 Subject: [PATCH 01/11] split test for NaN Signed-off-by: remzi <13716567376yh@gmail.com> --- .../src/main/python/arithmetic_ops_test.py | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index a1f1c244d33..38023e1716e 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -1088,27 +1088,39 @@ def _get_overflow_df_2cols(spark, data_types, values, expr): # test interval division overflow, such as interval / 0, Long.MinValue / -1 ... @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @pytest.mark.parametrize('data_type,value_pair', [ - (ByteType(), [timedelta(seconds=1), 0]), - (ShortType(), [timedelta(seconds=1), 0]), - (IntegerType(), [timedelta(seconds=1), 0]), - (LongType(), [timedelta(seconds=1), 0]), - (FloatType(), [timedelta(seconds=1), 0.0]), - (FloatType(), [timedelta(seconds=1), -0.0]), - (DoubleType(), [timedelta(seconds=1), 0.0]), - (DoubleType(), [timedelta(seconds=1), -0.0]), - (FloatType(), [timedelta(seconds=1), float('NaN')]), - (DoubleType(), [timedelta(seconds=1), float('NaN')]), - (FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]), - (DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]), - (FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]), - (DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]), - (LongType(), [MIN_DAY_TIME_INTERVAL, -1]), - (FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN - (DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN - (IntegerType(), [timedelta(microseconds=LONG_MIN), -1]) + #(ByteType(), [timedelta(seconds=1), 0]), + #(ShortType(), [timedelta(seconds=1), 0]), + #(IntegerType(), [timedelta(seconds=1), 0]), + #(LongType(), [timedelta(seconds=1), 0]), + #(FloatType(), [timedelta(seconds=1), 0.0]), + #(FloatType(), [timedelta(seconds=1), -0.0]), + #(DoubleType(), [timedelta(seconds=1), 0.0]), + #(DoubleType(), [timedelta(seconds=1), -0.0]), + #(FloatType(), [timedelta(seconds=1), float('NaN')]), + #(DoubleType(), [timedelta(seconds=1), float('NaN')]), + #(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]), + #(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]), + #(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]), + #(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]), + #(LongType(), [MIN_DAY_TIME_INTERVAL, -1]), + #(FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN + #(DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN + #(IntegerType(), [timedelta(microseconds=LONG_MIN), -1]) ], ids=idfn) def test_day_time_interval_division_overflow(data_type, value_pair): assert_gpu_and_cpu_error( df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(), conf={}, error_message='ArithmeticException') + + +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('data_type,value_pair', [ + (FloatType(), [timedelta(seconds=1), float('NaN')]), + (DoubleType(), [timedelta(seconds=1), float('NaN')]), +], ids=idfn) +def test_day_time_interval_division_nan(data_type, value_pair): + assert_gpu_and_cpu_error( + df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(), + 
+        conf={},
+        error_message='java.lang.ArithmeticException')

From 703f8bd8a76d22bfe4e410b18cd89570d35d81eb Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Tue, 19 Jul 2022 12:28:47 +0800
Subject: [PATCH 02/11] split test for divide by zero

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 .../src/main/python/arithmetic_ops_test.py | 31 ++++++++++++-------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index 38023e1716e..6d0f45a0515 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1088,23 +1088,11 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
 # test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
 @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value_pair', [
-    #(ByteType(), [timedelta(seconds=1), 0]),
-    #(ShortType(), [timedelta(seconds=1), 0]),
-    #(IntegerType(), [timedelta(seconds=1), 0]),
-    #(LongType(), [timedelta(seconds=1), 0]),
-    #(FloatType(), [timedelta(seconds=1), 0.0]),
-    #(FloatType(), [timedelta(seconds=1), -0.0]),
-    #(DoubleType(), [timedelta(seconds=1), 0.0]),
-    #(DoubleType(), [timedelta(seconds=1), -0.0]),
-    #(FloatType(), [timedelta(seconds=1), float('NaN')]),
-    #(DoubleType(), [timedelta(seconds=1), float('NaN')]),
     #(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
     #(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
     #(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
     #(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
     #(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
-    #(FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
-    #(DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
     #(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
 ], ids=idfn)
 def test_day_time_interval_division_overflow(data_type, value_pair):
@@ -1113,6 +1101,25 @@ def test_day_time_interval_division_overflow(data_type, value_pair):
         conf={},
         error_message='ArithmeticException')
 
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value_pair', [
+    (ByteType(), [timedelta(seconds=1), 0]),
+    (ShortType(), [timedelta(seconds=1), 0]),
+    (IntegerType(), [timedelta(seconds=1), 0]),
+    (LongType(), [timedelta(seconds=1), 0]),
+    (FloatType(), [timedelta(seconds=1), 0.0]),
+    (FloatType(), [timedelta(seconds=1), -0.0]),
+    (DoubleType(), [timedelta(seconds=1), 0.0]),
+    (DoubleType(), [timedelta(seconds=1), -0.0]),
+    (FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
+    (DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
+], ids=idfn)
+def test_day_time_interval_division_divide_by_zero(data_type, value_pair):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
+        conf={},
+        error_message='SparkArithmeticException: Division by zero.')
+
 @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value_pair', [

From ded355011bd25d254c4ce9f24082a8ea91e6bd2d Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Tue, 19 Jul 2022 12:44:18 +0800
Subject: [PATCH 03/11] temp save

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 integration_tests/src/main/python/arithmetic_ops_test.py | 2 +-
 .../apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++++
 .../spark/sql/rapids/shims/intervalExpressions.scala | 8 +++++++-
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index 6d0f45a0515..9e44905ff00 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1093,7 +1093,7 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
     #(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
     #(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
     #(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
-    #(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
+    (IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
 ], ids=idfn)
 def test_day_time_interval_division_overflow(data_type, value_pair):
     assert_gpu_and_cpu_error(
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index a6e9ecd235b..c60a4e4b03f 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -64,4 +64,8 @@ object RapidsErrorUtils {
       value, toType.precision, toType.scale, context
     )
   }
+
+  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
+    QueryExecutionErrors.arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
+  }
 }
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
index ab601bbb7e5..111a72681e9 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -22,6 +22,7 @@ import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, RoundMod
 import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant}
+import org.apache.spark.sql.rapids.GpuDivModLike.makeZeroScalar
 import org.apache.spark.sql.types._
 
 object IntervalUtils extends Arm {
@@ -260,7 +261,7 @@ object IntervalUtils extends Arm {
       withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
         withResource(isMin.and(isNegOne)) { invalid =>
           if (BoolUtils.isAnyValidTrue(invalid)) {
-            throw new ArithmeticException("overflow occurs")
+            throw RapidsErrorUtils.overflowInIntegralDivideError()
           }
         }
       }
@@ -523,6 +524,11 @@ case class GpuDivideDTInterval(
   }
 
   override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
+    withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
+      if (num.getBase.contains(zeroScalar)) {
+        throw RapidsErrorUtils.divByZeroError(origin)
+      }
+    }
     doColumnar(interval.getBase, num.getBase, num.dataType)
   }
 

From 759aa31f3a9385d99dcee8d2469245cb5ac767e3 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Tue, 19 Jul 2022 13:09:25 +0800
Subject: [PATCH 04/11] split tests

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 .../src/main/python/arithmetic_ops_test.py | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index 9e44905ff00..6aab090c350 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1088,18 +1088,27 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
 # test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
 @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value_pair', [
-    #(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
-    #(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
-    #(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
-    #(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
-    #(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
+    (LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
     (IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
 ], ids=idfn)
 def test_day_time_interval_division_overflow(data_type, value_pair):
     assert_gpu_and_cpu_error(
         df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
         conf={},
-        error_message='ArithmeticException')
+        error_message='SparkArithmeticException: Overflow in integral divide.')
+
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
+@pytest.mark.parametrize('data_type,value_pair', [
+    (FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
+    (DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
+    (FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
+    (DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
+], ids=idfn)
+def test_day_time_interval_division_round_overflow(data_type, value_pair):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
+        conf={},
+        error_message='java.lang.ArithmeticException')
 
 @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value_pair', [

From 5a322b2d75f02719b3b8239701b24f616e11676f Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Tue, 19 Jul 2022 13:15:04 +0800
Subject: [PATCH 05/11] fix lint

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 .../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index c60a4e4b03f..a43ba0c8a50 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -66,6 +66,8 @@ object RapidsErrorUtils {
   }
 
   def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
-    QueryExecutionErrors.arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
+    QueryExecutionErrors.arithmeticOverflowError(
+      "Overflow in integral divide", "try_divide", context
+    )
   }
 }

From a9b052b87c02490c0a8fbe14f3c3b3292ea3cb79 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Wed, 20 Jul 2022 09:30:00 +0800
Subject: [PATCH 06/11] add error utils for 3.1 and 3.2

Signed-off-by: remzi <13716567376yh@gmail.com>
---
.../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++++ .../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++++ .../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++++ .../org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index b5f4cf00981..c598740cf9f 100644 --- a/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/311until320-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -60,4 +60,8 @@ object RapidsErrorUtils { new ArithmeticException(s"${value.toDebugString} cannot be represented as " + s"Decimal(${toType.precision}, ${toType.scale}).") } + + def overflowInIntegralDivideError(context: String = ""): ArithmeticException = { + new ArithmeticException("Overflow in integral divide.") + } } diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index b5f4cf00981..c598740cf9f 100644 --- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -60,4 +60,8 @@ object RapidsErrorUtils { new ArithmeticException(s"${value.toDebugString} cannot be represented as " + s"Decimal(${toType.precision}, ${toType.scale}).") } + + def overflowInIntegralDivideError(context: String = ""): ArithmeticException = { + new ArithmeticException("Overflow in integral divide.") + } } diff --git a/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index 1dd6219a0ff..0856945ed8f 100644 --- a/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/320until330-nondb/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -62,5 +62,9 @@ object RapidsErrorUtils { value, toType.precision, toType.scale ) } + + def overflowInIntegralDivideError(context: String = ""): ArithmeticException = { + QueryExecutionErrors.overflowInIntegralDivideError() + } } diff --git a/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index ab15d817065..fc48fb72d3f 100644 --- a/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -65,4 +65,8 @@ object RapidsErrorUtils { value, toType.precision, toType.scale ) } + + def overflowInIntegralDivideError(context: String = ""): ArithmeticException = { + QueryExecutionErrors.overflowInIntegralDivideError() + } } From dda156037570df504ce30c616c8cefbe5e014456 Mon Sep 17 00:00:00 2001 From: remzi <13716567376yh@gmail.com> Date: Wed, 20 Jul 2022 10:46:04 +0800 Subject: [PATCH 07/11] update the case of divided by zero scalar Signed-off-by: remzi <13716567376yh@gmail.com> --- .../src/main/python/arithmetic_ops_test.py | 18 ++++++++++++++++++ 
 .../sql/rapids/shims/intervalExpressions.scala | 5 +++++
 2 files changed, 23 insertions(+)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index 6aab090c350..e8f11186a31 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1076,6 +1076,14 @@ def test_day_time_interval_division_number_no_overflow2(data_gen):
         # avoid dividing by 0
         lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))
 
+def _get_overflow_df_1col(spark, data_type, value, expr):
+    return spark.createDataFrame(
+        SparkContext.getOrCreate().parallelize([value]),
+        StructType([
+            StructField('a', data_type)
+        ])
+    ).selectExpr(expr)
+
 def _get_overflow_df_2cols(spark, data_types, values, expr):
     return spark.createDataFrame(
         SparkContext.getOrCreate().parallelize([values]),
@@ -1129,6 +1137,16 @@ def test_day_time_interval_division_divide_by_zero(data_type, value_pair):
         conf={},
         error_message='SparkArithmeticException: Division by zero.')
 
+@pytest.mark.parametrize('value,zero_literal', [
+    (timedelta(seconds=1), '0'),
+    (timedelta(seconds=1), '0.0f'),
+    (timedelta(seconds=1), '-0.0f'),
+], ids=idfn)
+def test_day_time_interval_division_divide_by_zero_scalar(value, zero_literal):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [value], 'a / ' + zero_literal).collect(),
+        conf={},
+        error_message='SparkArithmeticException: Division by zero.')
 
 @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value_pair', [
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
index 111a72681e9..8cc6f06ed42 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -520,6 +520,11 @@ case class GpuDivideDTInterval(
   override def right: Expression = num
 
   override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
+    withResource(makeZeroScalar(numScalar.getBase.getType)) { zeroScalar =>
+      if (numScalar.getBase.equals(zeroScalar)) {
+        throw RapidsErrorUtils.divByZeroError(origin)
+      }
+    }
     doColumnar(interval.getBase, numScalar.getBase, num.dataType)
   }
 

From 2df9453070c440455e4ab69da40bc89d6339ceed Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Wed, 20 Jul 2022 11:21:42 +0800
Subject: [PATCH 08/11] update other 2 doColumnar functions and add tests

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 .../src/main/python/arithmetic_ops_test.py | 28 ++++++++++++++-----
 .../rapids/shims/intervalExpressions.scala | 5 ++++
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index e8f11186a31..c7f6722c2da 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1131,20 +1131,34 @@ def test_day_time_interval_division_round_overflow(data_type, value_pair):
     (FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
     (DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
 ], ids=idfn)
-def test_day_time_interval_division_divide_by_zero(data_type, value_pair):
+def test_day_time_interval_divided_by_zero(data_type, value_pair):
     assert_gpu_and_cpu_error(
         df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
         conf={},
         error_message='SparkArithmeticException: Division by zero.')
 
-@pytest.mark.parametrize('value,zero_literal', [
-    (timedelta(seconds=1), '0'),
-    (timedelta(seconds=1), '0.0f'),
-    (timedelta(seconds=1), '-0.0f'),
+@pytest.mark.parametrize('zero_literal', ['0', '0.0f', '-0.0f'], ids=idfn)
+def test_day_time_interval_divided_by_zero_scalar(zero_literal):
+    assert_gpu_and_cpu_error(
+        df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [timedelta(seconds=1)], 'a / ' + zero_literal).collect(),
+        conf={},
+        error_message='SparkArithmeticException: Division by zero.')
+
+@pytest.mark.parametrize('data_type,value', [
+    (ByteType(), 0),
+    (ShortType(), 0),
+    (IntegerType(), 0),
+    (LongType(), 0),
+    (FloatType(), 0.0),
+    (FloatType(), -0.0),
+    (DoubleType(), 0.0),
+    (DoubleType(), -0.0),
+    (FloatType(), 0.0),
+    (DoubleType(), 0.0),
 ], ids=idfn)
-def test_day_time_interval_division_divide_by_zero_scalar(value, zero_literal):
+def test_day_time_interval_scalar_divided_by_zero(data_type, value):
     assert_gpu_and_cpu_error(
-        df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [value], 'a / ' + zero_literal).collect(),
+        df_fun=lambda spark: _get_overflow_df_1col(spark, data_type, [value], 'INTERVAL 1 SECOND / a').collect(),
         conf={},
         error_message='SparkArithmeticException: Division by zero.')
 
diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
index 8cc6f06ed42..91952bb8613 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -538,6 +538,11 @@ case class GpuDivideDTInterval(
   }
 
   override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
+    withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
+      if (num.getBase.contains(zeroScalar)) {
+        throw RapidsErrorUtils.divByZeroError(origin)
+      }
+    }
     doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
   }
 

From 3e4c2006086f18b42bcf7474eb361b04f28cd9f5 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Wed, 20 Jul 2022 12:35:43 +0800
Subject: [PATCH 09/11] update overflow errors

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 .../spark/sql/rapids/shims/intervalExpressions.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
index 91952bb8613..785232f4f34 100644
--- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
+++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala
@@ -35,7 +35,7 @@ object IntervalUtils extends Arm {
     withResource(longCv.castTo(DType.INT32)) { intResult =>
      withResource(longCv.notEqualTo(intResult)) { notEquals =>
         if (BoolUtils.isAnyValidTrue(notEquals)) {
-          throw new ArithmeticException("overflow occurs")
+          throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
         } else {
           intResult.incRefCount()
         }
@@ -48,7 +48,7 @@
     withResource(Scalar.fromLong(minValue)) { minScalar =>
       withResource(decimal128Cv.lessThan(minScalar)) { lessThanMin =>
         if (BoolUtils.isAnyValidTrue(lessThanMin)) {
-          throw new ArithmeticException("overflow occurs")
+          throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
         }
       }
     }
@@ -57,7 +57,7 @@
     withResource(Scalar.fromLong(maxValue)) { maxScalar =>
       withResource(decimal128Cv.greaterThan(maxScalar)) { greaterThanMax =>
         if (BoolUtils.isAnyValidTrue(greaterThanMax)) {
-          throw new ArithmeticException("overflow occurs")
+          throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
         }
       }
     }
@@ -269,13 +269,13 @@
       case (lCv: ColumnVector, rS: Scalar) =>
         withResource(lCv.equalTo(minScalar)) { isMin =>
           if (getLong(rS) == -1L && BoolUtils.isAnyValidTrue(isMin)) {
-            throw new ArithmeticException("overflow occurs")
+            throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
           }
         }
       case (lS: Scalar, rCv: ColumnVector) =>
         withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
           if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) {
-            throw new ArithmeticException("overflow occurs")
+            throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
           }
         }
       case (lS: Scalar, rS: Scalar) =>

From 231661796b570caf95fe713aac50a097a32f70f9 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Wed, 20 Jul 2022 14:17:26 +0800
Subject: [PATCH 10/11] remove redundant test cases

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 integration_tests/src/main/python/arithmetic_ops_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index c7f6722c2da..c6beb263baa 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1153,8 +1153,6 @@ def test_day_time_interval_divided_by_zero_scalar(zero_literal):
     (FloatType(), -0.0),
     (DoubleType(), 0.0),
     (DoubleType(), -0.0),
-    (FloatType(), 0.0),
-    (DoubleType(), 0.0),
 ], ids=idfn)
 def test_day_time_interval_scalar_divided_by_zero(data_type, value):
     assert_gpu_and_cpu_error(

From f64f41b46f12056a1f22f07b6264be0d71e9ec15 Mon Sep 17 00:00:00 2001
From: remzi <13716567376yh@gmail.com>
Date: Wed, 20 Jul 2022 18:57:23 +0800
Subject: [PATCH 11/11] skip the test if before version 330

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 integration_tests/src/main/python/arithmetic_ops_test.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index c6beb263baa..f335fc0016a 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1137,13 +1137,15 @@ def test_day_time_interval_divided_by_zero(data_type, value_pair):
         conf={},
         error_message='SparkArithmeticException: Division by zero.')
 
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('zero_literal', ['0', '0.0f', '-0.0f'], ids=idfn)
 def test_day_time_interval_divided_by_zero_scalar(zero_literal):
     assert_gpu_and_cpu_error(
         df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [timedelta(seconds=1)], 'a / ' + zero_literal).collect(),
         conf={},
         error_message='SparkArithmeticException: Division by zero.')
 
+@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
 @pytest.mark.parametrize('data_type,value', [
     (ByteType(), 0),
     (ShortType(), 0),