Update the interval division to throw the same type exceptions as Spark (#6019)

* split test for NaN

Signed-off-by: remzi <[email protected]>

* split test for divide by zero

Signed-off-by: remzi <[email protected]>

* temp save

Signed-off-by: remzi <[email protected]>

* split tests

Signed-off-by: remzi <[email protected]>

* fix lint

Signed-off-by: remzi <[email protected]>

* add error utils for 3.1 and 3.2

Signed-off-by: remzi <[email protected]>

* update the case of divided by zero scalar

Signed-off-by: remzi <[email protected]>

* update other 2 doColumnar functions and add tests

Signed-off-by: remzi <[email protected]>

* update overflow errors

Signed-off-by: remzi <[email protected]>

* remove redundant test cases

Signed-off-by: remzi <[email protected]>

* skip the test if before version 330

Signed-off-by: remzi <[email protected]>
HaoYang670 authored Jul 21, 2022
1 parent 392dba0 commit 071cb2f
Showing 7 changed files with 114 additions and 16 deletions.
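
For context, a minimal sketch of the CPU Spark behavior this commit makes the plugin match, assuming PySpark 3.3.0+ and a local session (the session setup and data values here are illustrative, not from the commit):

from datetime import timedelta

from pyspark.sql import SparkSession
from pyspark.sql.types import (DayTimeIntervalType, IntegerType,
                               StructField, StructType)

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(timedelta(seconds=1), 0)],
    StructType([
        StructField('a', DayTimeIntervalType()),
        StructField('b', IntegerType()),
    ]))

# On Spark 3.3.0+ this raises
# org.apache.spark.SparkArithmeticException: Division by zero.
# Dividing an interval of Long.MinValue microseconds by -1 similarly
# raises SparkArithmeticException: Overflow in integral divide.
df.selectExpr('a / b').collect()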
80 changes: 70 additions & 10 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1076,6 +1076,14 @@ def test_day_time_interval_division_number_no_overflow2(data_gen):
# avoid dividing by 0
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))

def _get_overflow_df_1col(spark, data_type, value, expr):
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize([value]),
StructType([
StructField('a', data_type)
])
).selectExpr(expr)

def _get_overflow_df_2cols(spark, data_types, values, expr):
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize([values]),
@@ -1086,6 +1094,30 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
).selectExpr(expr)

# test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
], ids=idfn)
def test_day_time_interval_division_overflow(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='SparkArithmeticException: Overflow in integral divide.')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
], ids=idfn)
def test_day_time_interval_division_round_overflow(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='java.lang.ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(ByteType(), [timedelta(seconds=1), 0]),
@@ -1096,19 +1128,47 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
(FloatType(), [timedelta(seconds=1), -0.0]),
(DoubleType(), [timedelta(seconds=1), 0.0]),
(DoubleType(), [timedelta(seconds=1), -0.0]),
(FloatType(), [timedelta(seconds=1), float('NaN')]),
(DoubleType(), [timedelta(seconds=1), float('NaN')]),
(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
(FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
(DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
], ids=idfn)
def test_day_time_interval_division_overflow(data_type, value_pair):
def test_day_time_interval_divided_by_zero(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='SparkArithmeticException: Division by zero.')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('zero_literal', ['0', '0.0f', '-0.0f'], ids=idfn)
def test_day_time_interval_divided_by_zero_scalar(zero_literal):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [timedelta(seconds=1)], 'a / ' + zero_literal).collect(),
conf={},
error_message='SparkArithmeticException: Division by zero.')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_type,value', [
(ByteType(), 0),
(ShortType(), 0),
(IntegerType(), 0),
(LongType(), 0),
(FloatType(), 0.0),
(FloatType(), -0.0),
(DoubleType(), 0.0),
(DoubleType(), -0.0),
], ids=idfn)
def test_day_time_interval_scalar_divided_by_zero(data_type, value):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_1col(spark, data_type, [value], 'INTERVAL 1 SECOND / a').collect(),
conf={},
error_message='SparkArithmeticException: Division by zero.')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(FloatType(), [timedelta(seconds=1), float('NaN')]),
(DoubleType(), [timedelta(seconds=1), float('NaN')]),
], ids=idfn)
def test_day_time_interval_division_nan(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='ArithmeticException')
error_message='java.lang.ArithmeticException')
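
These tests lean on assert_gpu_and_cpu_error from the integration tests' asserts module (integration_tests/src/main/python/asserts.py). As a rough, hypothetical simplification of its contract — not the actual helper — it runs the query once on CPU Spark and once with the plugin enabled, and requires both runs to fail with an exception containing error_message:

# Hypothetical simplification of assert_gpu_and_cpu_error; the real
# helper lives in the repo's asserts module.
from spark_session import with_spark_session  # repo-local helper (assumed)

def assert_gpu_and_cpu_error(df_fun, conf, error_message):
    # Run on CPU Spark first, then with the RAPIDS plugin enabled.
    for sql_enabled in ('false', 'true'):
        run_conf = dict(conf, **{'spark.rapids.sql.enabled': sql_enabled})
        try:
            with_spark_session(df_fun, conf=run_conf)
        except Exception as e:
            # Both runs must fail with the expected message.
            assert error_message in str(e), \
                'expected "%s" in "%s"' % (error_message, e)
        else:
            raise AssertionError('query was expected to fail')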
@@ -60,4 +60,8 @@ object RapidsErrorUtils {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
new ArithmeticException("Overflow in integral divide.")
}
}
@@ -60,4 +60,8 @@ object RapidsErrorUtils {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
new ArithmeticException("Overflow in integral divide.")
}
}
@@ -62,5 +62,9 @@ object RapidsErrorUtils {
value, toType.precision, toType.scale
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.overflowInIntegralDivideError()
}
}

@@ -65,4 +65,8 @@ object RapidsErrorUtils {
value, toType.precision, toType.scale
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.overflowInIntegralDivideError()
}
}
@@ -64,4 +64,10 @@ object RapidsErrorUtils {
value, toType.precision, toType.scale, context
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(
"Overflow in integral divide", "try_divide", context
)
}
}
@@ -22,6 +22,7 @@ import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, RoundMod
import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}

import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant}
import org.apache.spark.sql.rapids.GpuDivModLike.makeZeroScalar
import org.apache.spark.sql.types._

object IntervalUtils extends Arm {
@@ -34,7 +35,7 @@ object IntervalUtils extends Arm {
withResource(longCv.castTo(DType.INT32)) { intResult =>
withResource(longCv.notEqualTo(intResult)) { notEquals =>
if (BoolUtils.isAnyValidTrue(notEquals)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
} else {
intResult.incRefCount()
}
@@ -47,7 +48,7 @@
withResource(Scalar.fromLong(minValue)) { minScalar =>
withResource(decimal128Cv.lessThan(minScalar)) { lessThanMin =>
if (BoolUtils.isAnyValidTrue(lessThanMin)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
}
@@ -56,7 +57,7 @@
withResource(Scalar.fromLong(maxValue)) { maxScalar =>
withResource(decimal128Cv.greaterThan(maxScalar)) { greaterThanMax =>
if (BoolUtils.isAnyValidTrue(greaterThanMax)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
}
@@ -260,21 +261,21 @@
withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
withResource(isMin.and(isNegOne)) { invalid =>
if (BoolUtils.isAnyValidTrue(invalid)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.overflowInIntegralDivideError()
}
}
}
}
case (lCv: ColumnVector, rS: Scalar) =>
withResource(lCv.equalTo(minScalar)) { isMin =>
if (getLong(rS) == -1L && BoolUtils.isAnyValidTrue(isMin)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
case (lS: Scalar, rCv: ColumnVector) =>
withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
case (lS: Scalar, rS: Scalar) =>
@@ -519,14 +520,29 @@ case class GpuDivideDTInterval(
override def right: Expression = num

override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
withResource(makeZeroScalar(numScalar.getBase.getType)) { zeroScalar =>
if (numScalar.getBase.equals(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)
}
}
doColumnar(interval.getBase, numScalar.getBase, num.dataType)
}

override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
if (num.getBase.contains(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)
}
}
doColumnar(interval.getBase, num.getBase, num.dataType)
}

override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
if (num.getBase.contains(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)
}
}
doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
}

