Update interval division to throw the same exception types as Spark #6019
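The intent of this change: for day-time interval division, the GPU plugin should raise the same exception types Spark itself raises (SparkArithmeticException for division by zero and integral-divide overflow on Spark 3.3+, java.lang.ArithmeticException for rounding/NaN failures). A minimal PySpark sketch of the CPU behavior being matched — illustrative only, and assuming a Spark 3.3+ session:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    try:
        # On Spark 3.3+ this raises org.apache.spark.SparkArithmeticException:
        # "Division by zero", the same type the GPU plugin should now throw.
        spark.sql("SELECT INTERVAL '1' SECOND / 0").collect()
    except Exception as e:
        print(type(e).__name__, e)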

80 changes: 70 additions & 10 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -1076,6 +1076,14 @@ def test_day_time_interval_division_number_no_overflow2(data_gen):
# avoid dividing by 0
lambda spark: gen_df(spark, gen_list).selectExpr("_c1 / case when _c2 = 0 then cast(1 as {}) else _c2 end".format(to_cast_string(data_gen.data_type))))

# Builds a one-column DataFrame (column 'a') from the single-row `value` and
# applies `expr`, so a test can trigger the same error on both CPU and GPU.
def _get_overflow_df_1col(spark, data_type, value, expr):
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize([value]),
StructType([
StructField('a', data_type)
])
).selectExpr(expr)

def _get_overflow_df_2cols(spark, data_types, values, expr):
return spark.createDataFrame(
SparkContext.getOrCreate().parallelize([values]),
@@ -1086,6 +1094,30 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
).selectExpr(expr)

# test interval division overflow, such as interval / 0, Long.MinValue / -1 ...
@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
], ids=idfn)
def test_day_time_interval_division_overflow(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='SparkArithmeticException: Overflow in integral divide.')
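Both parametrizations reduce to the classic two's-complement corner case: a day-time interval is backed by a signed 64-bit microsecond count, and Long.MinValue / -1 has no representable result. A quick arithmetic sketch of why:

    LONG_MIN = -(2 ** 63)        # minimum signed 64-bit value
    LONG_MAX = 2 ** 63 - 1       # maximum signed 64-bit value
    # The true quotient LONG_MIN / -1 is 2**63, one past LONG_MAX,
    # hence "Overflow in integral divide".
    assert -LONG_MIN == 2 ** 63 > LONG_MAX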

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
], ids=idfn)
def test_day_time_interval_division_round_overflow(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='java.lang.ArithmeticException')
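Dividing by a fraction in (0, 1) scales the interval up: MAX_DAY_TIME_INTERVAL / 0.1 asks for roughly ten times the representable microsecond range, so the half-up rounding step cannot produce a valid long. A magnitude sketch:

    LONG_MAX = 2 ** 63 - 1                     # upper bound of the microsecond count
    assert float(LONG_MAX) / 0.1 > LONG_MAX    # ~9.2e19, far outside the long range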

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(ByteType(), [timedelta(seconds=1), 0]),
@@ -1096,19 +1128,47 @@ def _get_overflow_df_2cols(spark, data_types, values, expr):
(FloatType(), [timedelta(seconds=1), -0.0]),
(DoubleType(), [timedelta(seconds=1), 0.0]),
(DoubleType(), [timedelta(seconds=1), -0.0]),
(FloatType(), [timedelta(seconds=1), float('NaN')]),
(DoubleType(), [timedelta(seconds=1), float('NaN')]),
(FloatType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MAX_DAY_TIME_INTERVAL, 0.1]),
(FloatType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(DoubleType(), [MIN_DAY_TIME_INTERVAL, 0.1]),
(LongType(), [MIN_DAY_TIME_INTERVAL, -1]),
(FloatType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
(DoubleType(), [timedelta(seconds=0), 0.0]), # 0 / 0 = NaN
(IntegerType(), [timedelta(microseconds=LONG_MIN), -1])
], ids=idfn)
def test_day_time_interval_division_overflow(data_type, value_pair):
def test_day_time_interval_divided_by_zero(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='SparkArithmeticException: Division by zero.')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark 3.3.0')
@pytest.mark.parametrize('zero_literal', ['0', '0.0f', '-0.0f'], ids=idfn)
def test_day_time_interval_divided_by_zero_scalar(zero_literal):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_1col(spark, DayTimeIntervalType(), [timedelta(seconds=1)], 'a / ' + zero_literal).collect(),
conf={},
error_message='SparkArithmeticException: Division by zero.')
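The '-0.0f' literal is covered deliberately: IEEE 754 negative zero compares equal to zero, so dividing by it must raise the same division-by-zero error as '0':

    assert -0.0 == 0.0   # IEEE 754: -0.0 == 0.0, so 'a / -0.0f' is still division by zero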

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_type,value', [
(ByteType(), 0),
(ShortType(), 0),
(IntegerType(), 0),
(LongType(), 0),
(FloatType(), 0.0),
(FloatType(), -0.0),
(DoubleType(), 0.0),
(DoubleType(), -0.0),
], ids=idfn)
def test_day_time_interval_scalar_divided_by_zero(data_type, value):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_1col(spark, data_type, [value], 'INTERVAL 1 SECOND / a').collect(),
conf={},
error_message='SparkArithmeticException: Division by zero.')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_type,value_pair', [
(FloatType(), [timedelta(seconds=1), float('NaN')]),
(DoubleType(), [timedelta(seconds=1), float('NaN')]),
], ids=idfn)
def test_day_time_interval_division_nan(data_type, value_pair):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df_2cols(spark, [DayTimeIntervalType(), data_type], value_pair, 'a / b').collect(),
conf={},
error_message='ArithmeticException')
error_message='java.lang.ArithmeticException')
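A NaN divisor behaves differently from a zero divisor: the intermediate quotient is NaN, and rounding NaN to a 64-bit microsecond count is what fails, so the expected error is a plain java.lang.ArithmeticException rather than the division-by-zero message. A sketch of the underlying property:

    nan = float('nan')
    assert nan != nan    # NaN is unordered; 1 second / NaN has no numeric result,
                         # and rounding it to a long cannot succeed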
@@ -60,4 +60,8 @@ object RapidsErrorUtils {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
new ArithmeticException("Overflow in integral divide.")
}
}
@@ -60,4 +60,8 @@ object RapidsErrorUtils {
new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
s"Decimal(${toType.precision}, ${toType.scale}).")
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
new ArithmeticException("Overflow in integral divide.")
}
}
@@ -62,5 +62,9 @@ object RapidsErrorUtils {
value, toType.precision, toType.scale
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.overflowInIntegralDivideError()
}
}

@@ -65,4 +65,8 @@ object RapidsErrorUtils {
value, toType.precision, toType.scale
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.overflowInIntegralDivideError()
}
}
@@ -64,4 +64,10 @@ object RapidsErrorUtils {
value, toType.precision, toType.scale, context
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(
"Overflow in integral divide", "try_divide", context
)
}
}
@@ -22,6 +22,7 @@ import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, RoundMode
import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar}

import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant}
import org.apache.spark.sql.rapids.GpuDivModLike.makeZeroScalar
import org.apache.spark.sql.types._

object IntervalUtils extends Arm {
@@ -34,7 +35,7 @@
withResource(longCv.castTo(DType.INT32)) { intResult =>
withResource(longCv.notEqualTo(intResult)) { notEquals =>
if (BoolUtils.isAnyValidTrue(notEquals)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
} else {
intResult.incRefCount()
}
@@ -47,7 +48,7 @@
withResource(Scalar.fromLong(minValue)) { minScalar =>
withResource(decimal128Cv.lessThan(minScalar)) { lessThanMin =>
if (BoolUtils.isAnyValidTrue(lessThanMin)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
}
@@ -56,7 +57,7 @@
withResource(Scalar.fromLong(maxValue)) { maxScalar =>
withResource(decimal128Cv.greaterThan(maxScalar)) { greaterThanMax =>
if (BoolUtils.isAnyValidTrue(greaterThanMax)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
}
@@ -260,21 +261,21 @@
withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
withResource(isMin.and(isNegOne)) { invalid =>
if (BoolUtils.isAnyValidTrue(invalid)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.overflowInIntegralDivideError()
}
}
}
}
case (lCv: ColumnVector, rS: Scalar) =>
withResource(lCv.equalTo(minScalar)) { isMin =>
if (getLong(rS) == -1L && BoolUtils.isAnyValidTrue(isMin)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
case (lS: Scalar, rCv: ColumnVector) =>
withResource(rCv.equalTo(negOneScalar)) { isNegOne =>
if (getLong(lS) == min && BoolUtils.isAnyValidTrue(isNegOne)) {
throw new ArithmeticException("overflow occurs")
throw RapidsErrorUtils.arithmeticOverflowError("overflow occurs")
}
}
case (lS: Scalar, rS: Scalar) =>
@@ -519,14 +520,29 @@ case class GpuDivideDTInterval(
override def right: Expression = num

override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar): ColumnVector = {
withResource(makeZeroScalar(numScalar.getBase.getType)) { zeroScalar =>
if (numScalar.getBase.equals(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)
}
}
doColumnar(interval.getBase, numScalar.getBase, num.dataType)
}

override def doColumnar(interval: GpuColumnVector, num: GpuColumnVector): ColumnVector = {
withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
if (num.getBase.contains(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)

Collaborator: If we use a new error SparkArithmeticException, it's better to remove the original error ArithmeticException. Can we just update the original check?

Original check:

    if (IntervalUtils.hasZero(q)) {
      throw new ArithmeticException("overflow: interval / zero")
    }

Original error:

    java.lang.ArithmeticException: overflow: interval / zero
        at org.apache.spark.sql.rapids.shims.IntervalUtils$.divWithHalfUpModeWithOverflowCheck(intervalExpressions.scala:246)

Collaborator: Should also check another path: the numScalar can be zero in
override def doColumnar(interval: GpuColumnVector, numScalar: GpuScalar).

Author (re: the numScalar path): Updated!

Author (re: the original check): Updated! BTW, I didn't update this error:

    def divWithHalfUpModeWithOverflowCheck(p: BinaryOperable, q: BinaryOperable): ColumnVector = {
      // 1. overflow check q is 0
      if (IntervalUtils.hasZero(q)) {
        throw new ArithmeticException("overflow: interval / zero")
      }
      ...

because we don't have the value origin: Origin in this context.

}
}
doColumnar(interval.getBase, num.getBase, num.dataType)
}

override def doColumnar(intervalScalar: GpuScalar, num: GpuColumnVector): ColumnVector = {
withResource(makeZeroScalar(num.getBase.getType)) { zeroScalar =>
if (num.getBase.contains(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)
}
}
doColumnar(intervalScalar.getBase, num.getBase, num.dataType)
}
