Add support for DecimalType in Remainder for Spark 3.4 and DB 11.3 [databricks] #8302

Merged · 6 commits · May 21, 2023
73 changes: 22 additions & 51 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -272,7 +272,6 @@ def test_int_division_mixed(lhs, rhs):
            'a DIV b'))

@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn)
-@pytest.mark.skipif(is_databricks113_or_later() or is_spark_340_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/7595')
def test_mod(data_gen):
    data_type = data_gen.data_type
    assert_gpu_and_cpu_are_equal_collect(
@@ -283,20 +282,6 @@ def test_mod(data_gen):
            f.col('b') % f.lit(None).cast(data_type),
            f.col('a') % f.col('b')))

-# This test is only added because we are skipping test_mod for spark 3.4 and databricks 11.3 because of https://github.com/NVIDIA/spark-rapids/issues/7595
-# Once that is resolved we should remove this test and not skip test_mod for spark 3.4 and db 11.3
-@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
-@pytest.mark.skipif(not is_databricks113_or_later() and is_before_spark_340(), reason='https://github.com/NVIDIA/spark-rapids/issues/7595')
-def test_mod_db11_3(data_gen):
-    data_type = data_gen.data_type
-    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : binary_op_df(spark, data_gen).select(
-            f.col('a') % f.lit(100).cast(data_type),
-            f.lit(-12).cast(data_type) % f.col('b'),
-            f.lit(None).cast(data_type) % f.col('a'),
-            f.col('b') % f.lit(None).cast(data_type),
-            f.col('a') % f.col('b')))

# pmod currently falls back for Decimal(precision=38)
# https://github.com/NVIDIA/spark-rapids/issues/6336
# only testing numeric_gens because of https://github.com/NVIDIA/spark-rapids/issues/7553
@@ -397,49 +382,35 @@ def test_mod_pmod_by_zero_not_ansi(data_gen):
            'pmod(a, cast(0 as {}))'.format(string_type),
            'pmod(cast(-12 as {}), cast(0 as {}))'.format(string_type, string_type)),
        {'spark.sql.ansi.enabled': 'false'})
-    # Skip decimal tests for mod on spark 3.4 and databricks 11.3, reason=https://github.com/NVIDIA/spark-rapids/issues/7595
-    if is_before_spark_340() or not is_databricks113_or_later():
-        assert_gpu_and_cpu_are_equal_collect(
-            lambda spark : unary_op_df(spark, data_gen).selectExpr(
-                'a % (cast(0 as {}))'.format(string_type),
-                'cast(-12 as {}) % cast(0 as {})'.format(string_type, string_type)),
-            {'spark.sql.ansi.enabled': 'false'})
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr(
            'a % (cast(0 as {}))'.format(string_type),
            'cast(-12 as {}) % cast(0 as {})'.format(string_type, string_type)),
        {'spark.sql.ansi.enabled': 'false'})

-mod_mixed_decimals_lhs = [DecimalGen(6, 5), DecimalGen(6, 4), DecimalGen(5, 4), DecimalGen(5, 3), DecimalGen(4, 2),
-                          DecimalGen(3, -2), DecimalGen(16, 7), DecimalGen(19, 0), DecimalGen(30, 10)]
-mod_mixed_decimals_rhs = [DecimalGen(6, 3), DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12),
-                          DecimalGen(3, -3), DecimalGen(27, 7), DecimalGen(20, -3)]
-mod_mixed_lhs = [byte_gen, short_gen, int_gen, long_gen]
-mod_mixed_lhs.extend(pytest.param(t, marks=pytest.mark.skipif(is_databricks113_or_later() or not is_before_spark_340(),
-                     reason='https://github.com/NVIDIA/spark-rapids/issues/7595')) for t in mod_mixed_decimals_lhs)
-mod_mixed_rhs = [byte_gen, short_gen, int_gen, long_gen]
-mod_mixed_rhs.extend(pytest.param(t, marks=pytest.mark.skipif(is_databricks113_or_later() or not is_before_spark_340(),
-                     reason='https://github.com/NVIDIA/spark-rapids/issues/7595')) for t in mod_mixed_decimals_rhs)
-@pytest.mark.parametrize('lhs', mod_mixed_lhs, ids=idfn)
-@pytest.mark.parametrize('rhs', mod_mixed_rhs, ids=idfn)
@pytest.mark.parametrize('lhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(6, 5),
        DecimalGen(6, 4), DecimalGen(5, 4), DecimalGen(5, 3), DecimalGen(4, 2), DecimalGen(3, -2),
        DecimalGen(16, 7), DecimalGen(19, 0), DecimalGen(30, 10)], ids=idfn)
@pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(6, 3),
        DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3),
        DecimalGen(27, 7), DecimalGen(20, -3)], ids=idfn)
def test_mod_mixed(lhs, rhs):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"a % b"))

-@allow_non_gpu('ProjectExec')
-@pytest.mark.skipif(not is_databricks113_or_later() or is_spark_340_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/7595')
-@pytest.mark.parametrize('lhs', mod_mixed_decimals_lhs, ids=idfn)
-@pytest.mark.parametrize('rhs', mod_mixed_decimals_rhs, ids=idfn)
-def test_mod_fallback(lhs, rhs):
# See https://github.com/NVIDIA/spark-rapids/issues/8330
# Basically if we overflow on Decimal128 values when up-casting the operands, we need
# to fall back to CPU since we don't currently have enough precision to support that
# on the GPU.
@allow_non_gpu("ProjectExec", "Remainder")
@pytest.mark.skipif(not is_databricks113_or_later() and not is_spark_340_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/8330")
@pytest.mark.parametrize('lhs', [DecimalGen(38,0), DecimalGen(37,2), DecimalGen(38,5)], ids=idfn)
@pytest.mark.parametrize('rhs', [DecimalGen(27,7), DecimalGen(30,10), DecimalGen(38,1)], ids=idfn)
def test_mod_mixed_overflow_fallback(lhs, rhs):
    assert_gpu_fallback_collect(
-        lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"a % b"), 'Remainder')
        lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"a % b"), "Remainder")
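For context on why these generator pairs must fall back: a minimal sketch of the precision rule this PR introduces in DecimalRemainderChecks (the demo object below is illustrative only, not part of the PR):

```scala
// Recomputes the needed-precision rule for a % b on decimals.
object PrecisionCheckDemo {
  val Decimal128MaxPrecision = 38 // mirrors DType.DECIMAL128_MAX_PRECISION

  def neededScale(s1: Int, s2: Int): Int = math.max(s1, s2)

  // Integer digits must cover the wider operand; fractional digits the finer one.
  def neededPrecision(p1: Int, s1: Int, p2: Int, s2: Int): Int =
    math.max(p1 - s1, p2 - s2) + neededScale(s1, s2)

  def main(args: Array[String]): Unit = {
    // DecimalGen(38, 0) % DecimalGen(27, 7): max(38, 20) + max(0, 7) = 45 > 38,
    // which cannot fit in DECIMAL128, so the plan must fall back to the CPU.
    println(neededPrecision(38, 0, 27, 7)) // 45
    // DecimalGen(16, 7) % DecimalGen(15, 3): max(9, 12) + max(7, 3) = 19 <= 38, stays on GPU.
    println(neededPrecision(16, 7, 15, 3)) // 19
  }
}
```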

-# Split into 4 tests to permute https://github.com/NVIDIA/spark-rapids/issues/7553 failures
-# @pytest.mark.parametrize('lhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(6, 5),
-#         DecimalGen(6, 4), DecimalGen(5, 4), DecimalGen(5, 3), DecimalGen(4, 2), DecimalGen(3, -2),
-#         DecimalGen(16, 7), DecimalGen(19, 0), DecimalGen(30, 10)], ids=idfn)
-# @pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen, DecimalGen(6, 3),
-#         DecimalGen(10, -2), DecimalGen(15, 3), DecimalGen(30, 12), DecimalGen(3, -3),
-#         DecimalGen(27, 7), DecimalGen(20, -3)], ids=idfn)
-# def test_mod_mixed(lhs, rhs):
-#     assert_gpu_and_cpu_are_equal_collect(
-#         lambda spark : two_col_df(spark, lhs, rhs).selectExpr(f"a % b"))

@pytest.mark.parametrize('lhs', [byte_gen, short_gen, int_gen, long_gen], ids=idfn)
@pytest.mark.parametrize('rhs', [byte_gen, short_gen, int_gen, long_gen], ids=idfn)
def test_pmod_mixed_numeric(lhs, rhs):
@@ -26,7 +26,7 @@ import com.nvidia.spark.rapids.GpuOverrides.expr

import org.apache.spark.sql.catalyst.expressions.{Divide, Expression, IntegralDivide, Multiply, Remainder}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.rapids.{DecimalMultiplyChecks, GpuAnsi, GpuDecimalDivide, GpuDecimalMultiply, GpuDivide, GpuIntegralDecimalDivide, GpuIntegralDivide, GpuMultiply, GpuRemainder}
import org.apache.spark.sql.rapids.{DecimalMultiplyChecks, DecimalRemainderChecks, GpuAnsi, GpuDecimalDivide, GpuDecimalMultiply, GpuDecimalRemainder, GpuDivide, GpuIntegralDecimalDivide, GpuIntegralDivide, GpuMultiply, GpuRemainder}
import org.apache.spark.sql.types.DecimalType

object DecimalArithmeticOverrides {
@@ -99,21 +99,36 @@ object DecimalArithmeticOverrides {
}
}),

-      /**
-       * Because of https://github.com/NVIDIA/spark-rapids/issues/7595 we are not supporting
-       * Decimals for spark 3.4 and db 11.3. Once we do we should revert the changes made to the
-       * following tests test_mod, test_mod_mixed and test_mod_pmod_by_zero_not_ansi or we should
-       * just revert this commit
-       */
      expr[Remainder](
        "Remainder or modulo",
        ExprChecks.binaryProject(
          TypeSig.gpuNumeric, TypeSig.cpuNumeric,
-          ("lhs", TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric),
-          ("rhs", TypeSig.integral + TypeSig.fp, TypeSig.cpuNumeric)),
          ("lhs", TypeSig.gpuNumeric, TypeSig.cpuNumeric),
          ("rhs", TypeSig.gpuNumeric, TypeSig.cpuNumeric)),
[Review comment — Collaborator] nit: it would be nice to add a psNote for DECIMAL128 here to explain what we don't fully support. This is really minor if we think we can get full remainder support in before we ship 23.06. (Especially minor because we don't generate the support docs for anything but the oldest version of Spark that we support.) Oh well....
        (a, conf, p, r) => new BinaryExprMeta[Remainder](a, conf, p, r) {
          // See https://github.com/NVIDIA/spark-rapids/issues/8330
          // Basically if we overflow on Decimal128 values when up-casting the operands, we need
          // to fall back to CPU since we don't currently have enough precision to support that
          // on the GPU.
          override def tagExprForGpu(): Unit = {
            if (a.left.dataType.isInstanceOf[DecimalType] &&
                a.right.dataType.isInstanceOf[DecimalType]) {
              val lhsType = a.left.dataType.asInstanceOf[DecimalType]
              val rhsType = a.right.dataType.asInstanceOf[DecimalType]
              val needed = DecimalRemainderChecks.neededPrecision(lhsType, rhsType)
              if (needed > DType.DECIMAL128_MAX_PRECISION) {
                willNotWorkOnGpu(s"needed intermediate precision ($needed) will overflow " +
                  s"outside of the maximum available decimal128 precision")
              }
            }
[Review comment — @gerashegalov (Collaborator), May 20, 2023, on lines +114 to +123] Suggested change:

-          if (a.left.dataType.isInstanceOf[DecimalType] &&
-              a.right.dataType.isInstanceOf[DecimalType]) {
-            val lhsType = a.left.dataType.asInstanceOf[DecimalType]
-            val rhsType = a.right.dataType.asInstanceOf[DecimalType]
-            val needed = DecimalRemainderChecks.neededPrecision(lhsType, rhsType)
-            if (needed > DType.DECIMAL128_MAX_PRECISION) {
-              willNotWorkOnGpu(s"needed intermediate precision ($needed) will overflow " +
-                s"outside of the maximum available decimal128 precision")
-            }
-          }
+          (a.left.dataType, a.right.dataType) match {
+            case (lhsType: DecimalType, rhsType: DecimalType) =>
+              val needed = DecimalRemainderChecks.neededPrecision(lhsType, rhsType)
+              if (needed > DType.DECIMAL128_MAX_PRECISION) {
+                willNotWorkOnGpu(s"needed intermediate precision ($needed) will overflow " +
+                  s"outside of the maximum available decimal128 precision")
+              }
+            case _ => ()
+          }

(The pattern match avoids the repeated isInstanceOf/asInstanceOf checks.)

          }

          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuRemainder(lhs, rhs)
            if (lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]) {
              GpuDecimalRemainder(lhs, rhs)
            } else {
              GpuRemainder(lhs, rhs)
            }
        })
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
}
@@ -26,9 +26,11 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -199,6 +201,40 @@ case class GpuSubtract(

case class GpuRemainder(left: Expression, right: Expression)
extends GpuRemainderBase(left, right) {
  assert(!left.dataType.isInstanceOf[DecimalType] ||
    !right.dataType.isInstanceOf[DecimalType],
    "DecimalType remainder needs to be handled by GpuDecimalRemainder")
}

object DecimalRemainderChecks {
  def neededScale(lhs: DecimalType, rhs: DecimalType): Int =
    math.max(lhs.scale, rhs.scale)

  // For Remainder, the operands need to have the same precision (for CUDF to do the
  // computation) *and* the same scale (to account for the part of the remainder < 1 in the
  // output). This means that we first start with the needed scale (in this case the max of
  // the scales between the 2 operands), and then account for enough space (precision) to
  // store the resulting value without overflow.
  def neededPrecision(lhs: DecimalType, rhs: DecimalType): Int =
    math.max(lhs.precision - lhs.scale, rhs.precision - rhs.scale) + neededScale(lhs, rhs)

  def intermediateArgPrecision(lhs: DecimalType, rhs: DecimalType): Int =
    math.min(
      neededPrecision(lhs, rhs),
      DType.DECIMAL128_MAX_PRECISION)

  def intermediateLhsRhsType(
      lhs: DecimalType,
      rhs: DecimalType): DecimalType = {
    val precision = intermediateArgPrecision(lhs, rhs)
    val scale = neededScale(lhs, rhs)
    DecimalType(precision, scale)
  }
}
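A quick sanity check of the formulas above, as a standalone sketch (assumes Spark SQL on the classpath; the demo object is not part of the PR):

```scala
import org.apache.spark.sql.types.DecimalType

object IntermediateTypeDemo extends App {
  val lhs = DecimalType(30, 10)
  val rhs = DecimalType(27, 7)
  val scale = math.max(lhs.scale, rhs.scale)            // max(10, 7) = 10
  val precision = math.max(lhs.precision - lhs.scale,
    rhs.precision - rhs.scale) + scale                  // max(20, 20) + 10 = 30
  // Both operands get cast to DECIMAL(30, 10) before cudf's mod is invoked.
  println(DecimalType(precision, scale))                // DecimalType(30,10)
}
```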

case class GpuDecimalRemainder(left: Expression, right: Expression)
extends GpuRemainderBase(left, right) with Logging {

// scalastyle:off
// The formula follows Hive which is based on the SQL standard and MS SQL:
// https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
@@ -215,6 +251,59 @@ case class GpuRemainder(left: Expression, right: Expression)
DecimalType.bounded(resultPrecision, resultScale)
}
}
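
The linked Hive rule, restated compactly as a sketch (the helper below is illustrative, not in the PR): the result scale is max(s1, s2) and the result precision is min(p1 - s1, p2 - s2) plus that scale.

```scala
// Hive / SQL-standard result type for a % b, returned as (precision, scale).
def remainderResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
  val scale = math.max(s1, s2)
  (math.min(p1 - s1, p2 - s2) + scale, scale)
}

// remainderResultType(6, 5, 6, 3)    == (6, 5)   -> DECIMAL(6,5)
// remainderResultType(30, 10, 27, 7) == (30, 10) -> DECIMAL(30,10)
```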

def decimalType: DecimalType = dataType match {
case DecimalType.Fixed(_, _) => dataType.asInstanceOf[DecimalType]
case LongType => DecimalType.LongDecimal
}

private[this] lazy val lhsType: DecimalType = DecimalUtil.asDecimalType(left.dataType)
private[this] lazy val rhsType: DecimalType = DecimalUtil.asDecimalType(right.dataType)

  // This is the type that the LHS will be cast to. Its precision and scale come from the
  // shared intermediate type computed by DecimalRemainderChecks, so that CUDF can compute
  // the remainder directly at the desired output scale.
  private[this] lazy val intermediateLhsType =
    DecimalRemainderChecks.intermediateLhsRhsType(lhsType, rhsType)

  // This is the type that the RHS will be cast to. It is the same shared intermediate type,
  // so both operands agree on precision and scale before the mod is performed.
  private[this] lazy val intermediateRhsType =
    DecimalRemainderChecks.intermediateLhsRhsType(lhsType, rhsType)

private[this] def divByZeroFixes(rhs: ColumnVector): ColumnVector = {
if (failOnError) {
withResource(GpuDivModLike.makeZeroScalar(rhs.getType)) { zeroScalar =>
if (rhs.contains(zeroScalar)) {
throw RapidsErrorUtils.divByZeroError(origin)
}
}
rhs.incRefCount()
} else {
GpuDivModLike.replaceZeroWithNull(rhs)
}
}
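
The two branches above follow Spark's divide-by-zero contract. A CPU-side sketch of the same semantics, with Option standing in for a nullable column (illustrative only):

```scala
// ANSI on: x % 0 raises an error; ANSI off: x % 0 yields NULL.
def mod(a: BigDecimal, b: BigDecimal, ansi: Boolean): Option[BigDecimal] =
  if (b.signum == 0) {
    if (ansi) throw new ArithmeticException("divide by zero") else None
  } else {
    Some(a % b)
  }

// mod(BigDecimal(7), BigDecimal(0), ansi = false) => None (NULL)
// mod(BigDecimal(7), BigDecimal(2), ansi = true)  => Some(1)
```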

override def columnarEval(batch: ColumnarBatch): Any = {
val castLhs = withResource(GpuExpressionsUtils.columnarEvalToColumn(left, batch)) { lhs =>
GpuCast.doCast(lhs.getBase, lhs.dataType(), intermediateLhsType, ansiMode = failOnError,
legacyCastToString = false, stringToDateAnsiModeEnabled = false)
}
withResource(castLhs) { castLhs =>
val castRhs = withResource(GpuExpressionsUtils.columnarEvalToColumn(right, batch)) { rhs =>
withResource(divByZeroFixes(rhs.getBase)) { fixed =>
GpuCast.doCast(fixed, rhs.dataType(), intermediateRhsType, ansiMode = failOnError,
legacyCastToString = false, stringToDateAnsiModeEnabled = false)
}
}
withResource(castRhs) { castRhs =>
GpuColumnVector.from(
castLhs.mod(castRhs, GpuColumnVector.getNonNestedRapidsType(dataType)),
dataType)
}
}
}
}
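
To make the cast-then-mod flow concrete, a hedged walk-through with java.math.BigDecimal for DECIMAL(6,5) % DECIMAL(6,3), whose shared intermediate type is DECIMAL(8,5) by the formulas above (names are illustrative):

```scala
import java.math.{BigDecimal => JBD}

object CastThenModDemo extends App {
  val lhs = new JBD("1.23456")  // DECIMAL(6,5)
  val rhs = new JBD("123.456")  // DECIMAL(6,3)
  // Shared intermediate type: scale = max(5, 3) = 5, precision = max(1, 3) + 5 = 8.
  val scale = 5
  val a = lhs.setScale(scale)   // 1.23456
  val b = rhs.setScale(scale)   // 123.45600: same scale, so the fixed-point digits line up
  // The remainder comes out directly at the output scale of DECIMAL(6,5).
  println(a.remainder(b))       // 1.23456
}
```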

case class GpuPmod(
@@ -310,7 +399,8 @@ case class GpuDecimalMultiply(
case class GpuIntegralDivide(
left: Expression,
right: Expression) extends GpuIntegralDivideParent(left, right) {
-  assert(!left.dataType.isInstanceOf[DecimalType],
  assert(!left.dataType.isInstanceOf[DecimalType] ||
    !right.dataType.isInstanceOf[DecimalType],
"DecimalType integral divides need to be handled by GpuIntegralDecimalDivide")
}

Expand Down Expand Up @@ -343,7 +433,3 @@ case class GpuIntegralDecimalDivide(
DecimalType.bounded(intDig, 0)
}
}