Skip to content

Commit

Permalink
Adopt changes from JNI for casting from float to decimal (#10917)
Browse files Browse the repository at this point in the history
* Debugging

Signed-off-by: Nghia Truong <[email protected]>

* Add test

Signed-off-by: Nghia Truong <[email protected]>

* Adopt changes from JNI

Signed-off-by: Nghia Truong <[email protected]>

* WIP

Signed-off-by: Nghia Truong <[email protected]>

* Enable SparkUT test

Signed-off-by: Nghia Truong <[email protected]>

* Reduce test threshold

Signed-off-by: Nghia Truong <[email protected]>

* Add unit tests

Signed-off-by: Nghia Truong <[email protected]>

* Cleanup

Signed-off-by: Nghia Truong <[email protected]>

* Update python tests

Signed-off-by: Nghia Truong <[email protected]>

* Update unit tests

Signed-off-by: Nghia Truong <[email protected]>

* cast float to decimal: print number of failures

Signed-off-by: Nghia Truong <[email protected]>

* Change relative error

Signed-off-by: Nghia Truong <[email protected]>

* Revert "cast float to decimal: print number of failures"

This reverts commit c185206.

# Conflicts:
#	tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala

* Revert "Enable SparkUT test"

This reverts commit 93422e8.

* Change issue number

Signed-off-by: Nghia Truong <[email protected]>

---------

Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia authored Jul 30, 2024
1 parent 4f7589a commit c9f1ab9
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 114 deletions.
23 changes: 21 additions & 2 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_cast_string_timestamp_fallback():
decimal_gen_32bit,
pytest.param(decimal_gen_32bit_neg_scale, marks=
pytest.mark.skipif(is_dataproc_serverless_runtime(),
reason="Dataproc Serverless does not support negative scale for Decimal cast")),
reason="Dataproc Serverless does not support negative scale for Decimal cast")),
DecimalGen(precision=7, scale=7),
decimal_gen_64bit, decimal_gen_128bit, DecimalGen(precision=30, scale=2),
DecimalGen(precision=36, scale=5), DecimalGen(precision=38, scale=0),
Expand Down Expand Up @@ -265,6 +265,25 @@ def test_cast_long_to_decimal_overflow():
lambda spark : unary_op_df(spark, long_gen).select(
f.col('a').cast(DecimalType(18, -1))))


_float_special_cases = [(float("inf"), 5.0), (float("-inf"), 5.0), (float("nan"), 5.0)]
@pytest.mark.parametrize('data_gen', [FloatGen(special_cases=_float_special_cases),
DoubleGen(special_cases=_float_special_cases)],
ids=idfn)
@pytest.mark.parametrize('to_type', [
DecimalType(7, 1),
DecimalType(9, 9),
DecimalType(15, 2),
DecimalType(15, 15),
DecimalType(30, 3),
DecimalType(5, -3),
DecimalType(3, 0)], ids=idfn)
def test_cast_floating_point_to_decimal(data_gen, to_type):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a'), f.col('a').cast(to_type)),
conf={'spark.rapids.sql.castFloatToDecimal.enabled': 'true'})

# casting these types to string should be passed
basic_gens_for_cast_to_string = [ByteGen, ShortGen, IntegerGen, LongGen, StringGen, BooleanGen, DateGen, TimestampGen]
basic_array_struct_gens_for_cast_to_string = [f() for f in basic_gens_for_cast_to_string] + [null_gen] + decimal_gens
Expand Down Expand Up @@ -310,7 +329,7 @@ def test_cast_array_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
{"spark.sql.legacy.castComplexTypesToString.enabled": legacy})

def test_cast_float_to_string():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, FloatGen()).selectExpr("cast(cast(a as string) as float)"),
Expand Down
117 changes: 11 additions & 106 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import java.util.Optional

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DecimalUtils, DType, RegexProgram, Scalar}
import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, RegexProgram, Scalar}
import ai.rapids.cudf
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.{CastStrings, GpuTimeZoneDB}
import com.nvidia.spark.rapids.jni.{CastStrings, DecimalUtils, GpuTimeZoneDB}
import com.nvidia.spark.rapids.shims.{AnsiUtil, GpuCastShims, GpuIntervalUtils, GpuTypeShims, SparkShimImpl, YearParseUtil}
import org.apache.commons.text.StringEscapeUtils

Expand Down Expand Up @@ -192,7 +192,7 @@ object CastOptions {
val ARITH_ANSI_OPTIONS = new CastOptions(false, true, false)
val TO_PRETTY_STRING_OPTIONS = ToPrettyStringOptions

def getArithmeticCastOptions(failOnError: Boolean): CastOptions =
def getArithmeticCastOptions(failOnError: Boolean): CastOptions =
if (failOnError) ARITH_ANSI_OPTIONS else DEFAULT_CAST_OPTIONS

object ToPrettyStringOptions extends CastOptions(false, false, false,
Expand Down Expand Up @@ -628,7 +628,7 @@ object GpuCast {
case (TimestampType, DateType) if options.timeZoneId.isDefined =>
val zoneId = DateTimeUtils.getZoneId(options.timeZoneId.get)
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.asInstanceOf[ColumnVector],
zoneId.normalized())) {
zoneId.normalized())) {
shifted => shifted.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))
}
case _ =>
Expand Down Expand Up @@ -696,49 +696,6 @@ object GpuCast {
}
}

/**
* Detects outlier values of a column given with specific range, and replaces them with
* a inputted substitution value.
*
* @param values ColumnVector to be performed with range check
* @param minValue Named parameter for function to create Scalar representing range minimum value
* @param maxValue Named parameter for function to create Scalar representing range maximum value
* @param replaceValue Named parameter for function to create scalar to substitute outlier value
* @param inclusiveMin Whether the min value is included in the valid range or not
* @param inclusiveMax Whether the max value is included in the valid range or not
*/
private def replaceOutOfRangeValues(values: ColumnView,
minValue: => Scalar,
maxValue: => Scalar,
replaceValue: => Scalar,
inclusiveMin: Boolean,
inclusiveMax: Boolean): ColumnVector = {

withResource(minValue) { minValue =>
withResource(maxValue) { maxValue =>
val minPredicate = if (inclusiveMin) {
values.lessThan(minValue)
} else {
values.lessOrEqualTo(minValue)
}
withResource(minPredicate) { minPredicate =>
val maxPredicate = if (inclusiveMax) {
values.greaterThan(maxValue)
} else {
values.greaterOrEqualTo(maxValue)
}
withResource(maxPredicate) { maxPredicate =>
withResource(maxPredicate.or(minPredicate)) { rangePredicate =>
withResource(replaceValue) { nullScalar =>
rangePredicate.ifElse(nullScalar, values)
}
}
}
}
}
}
}

def castToString(
input: ColumnView,
fromDataType: DataType, options: CastOptions): ColumnVector = fromDataType match {
Expand Down Expand Up @@ -1638,65 +1595,13 @@ object GpuCast {
input: ColumnView,
dt: DecimalType,
ansiMode: Boolean): ColumnVector = {

// Approach to minimize difference between CPUCast and GPUCast:
// step 1. cast input to FLOAT64 (if necessary)
// step 2. cast FLOAT64 to container DECIMAL (who keeps one more digit for rounding)
// step 3. perform HALF_UP rounding on container DECIMAL
val checkedInput = withResource(input.castTo(DType.FLOAT64)) { double =>
val roundedDouble = double.round(dt.scale, cudf.RoundMode.HALF_UP)
withResource(roundedDouble) { rounded =>
// We rely on containerDecimal to perform preciser rounding. So, we have to take extra
// space cost of container into consideration when we run bound check.
val containerScaleBound = DType.DECIMAL128_MAX_PRECISION - (dt.scale + 1)
val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
if (ansiMode) {
assertValuesInRange[Double](rounded,
minValue = -bound,
maxValue = bound,
inclusiveMin = false,
inclusiveMax = false)
rounded.incRefCount()
} else {
replaceOutOfRangeValues(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false,
replaceValue = Scalar.fromNull(DType.FLOAT64))
}
}
}

withResource(checkedInput) { checked =>
val targetType = DecimalUtil.createCudfDecimal(dt)
// If target scale reaches DECIMAL128_MAX_PRECISION, container DECIMAL can not
// be created because of precision overflow. In this case, we perform casting op directly.
val casted = if (DType.DECIMAL128_MAX_PRECISION == dt.scale) {
checked.castTo(targetType)
} else {
// Increase precision by one along with scale in case of overflow, which may lead to
// the upcast of cuDF decimal type. If precision already hits the max precision, it is safe
// to increase the scale solely because we have checked and replaced out of range values.
val containerType = DecimalUtils.createDecimalType(
dt.precision + 1 min DType.DECIMAL128_MAX_PRECISION, dt.scale + 1)
withResource(checked.castTo(containerType)) { container =>
withResource(container.round(dt.scale, cudf.RoundMode.HALF_UP)) { rd =>
// The cast here is for cases that cuDF decimal type got promoted as precision + 1.
// Need to convert back to original cuDF type, to keep align with the precision.
rd.castTo(targetType)
}
}
}
// Cast NaN values to nulls
withResource(casted) { casted =>
withResource(input.isNan) { inputIsNan =>
withResource(Scalar.fromNull(targetType)) { nullScalar =>
inputIsNan.ifElse(nullScalar, casted)
}
}
}
val targetType = DecimalUtil.createCudfDecimal(dt)
val converted = DecimalUtils.floatingPointToDecimal(input, targetType, dt.precision)
if (ansiMode && converted.hasFailure) {
converted.result.close()
throw RapidsErrorUtils.arithmeticOverflowError(OVERFLOW_MESSAGE)
}
converted.result
}

def fixDecimalBounds(input: ColumnView,
Expand Down Expand Up @@ -1901,4 +1806,4 @@ case class GpuCast(

override def doColumnar(input: GpuColumnVector): ColumnVector =
doCast(input.getBase, input.dataType(), dataType, options)
}
}
44 changes: 40 additions & 4 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

private def compareFloatToStringResults(float: Boolean, fromCpu: Array[Row],
private def compareFloatToStringResults(float: Boolean, fromCpu: Array[Row],
fromGpu: Array[Row]): Unit = {
fromCpu.zip(fromGpu).foreach {
case (c, g) =>
Expand Down Expand Up @@ -438,12 +438,12 @@ class CastOpSuite extends GpuExpressionTestSuite {
}

test("cast float to string") {
testCastToString[Float](DataTypes.FloatType, comparisonFunc =
testCastToString[Float](DataTypes.FloatType, comparisonFunc =
Some(compareStringifiedFloats(true)))
}

test("cast double to string") {
testCastToString[Double](DataTypes.DoubleType, comparisonFunc =
testCastToString[Double](DataTypes.DoubleType, comparisonFunc =
Some(compareStringifiedFloats(false)))
}

Expand Down Expand Up @@ -693,6 +693,11 @@ class CastOpSuite extends GpuExpressionTestSuite {
List(-10, -1, 0, 1, 10).foreach { scale =>
testCastToDecimal(DataTypes.FloatType, scale,
customDataGenerator = Some(floatsIncludeNaNs))
assertThrows[Throwable] {
testCastToDecimal(DataTypes.FloatType, scale,
customDataGenerator = Some(floatsIncludeNaNs),
ansiEnabled = true)
}
}
}

Expand All @@ -710,6 +715,11 @@ class CastOpSuite extends GpuExpressionTestSuite {
List(-10, -1, 0, 1, 10).foreach { scale =>
testCastToDecimal(DataTypes.DoubleType, scale,
customDataGenerator = Some(doublesIncludeNaNs))
assertThrows[Throwable] {
testCastToDecimal(DataTypes.DoubleType, scale,
customDataGenerator = Some(doublesIncludeNaNs),
ansiEnabled = true)
}
}
}

Expand All @@ -729,6 +739,32 @@ class CastOpSuite extends GpuExpressionTestSuite {
customDataGenerator = Option(genDoubles))
}

test("cast float/double to decimal (borderline value rounding)") {
val genFloats_12_7: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(3527.61953125f))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.FloatType, precision = 12, scale = 7,
customDataGenerator = Option(genFloats_12_7))

val genDoubles_12_7: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(3527.61953125))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.DoubleType, precision = 12, scale = 7,
customDataGenerator = Option(genDoubles_12_7))

val genFloats_3_1: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(9.95f))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.FloatType, precision = 3, scale = 1,
customDataGenerator = Option(genFloats_3_1))

val genDoubles_3_1: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(9.95))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.DoubleType, precision = 3, scale = 1,
customDataGenerator = Option(genDoubles_3_1))
}

test("cast decimal to decimal") {
// fromScale == toScale
testCastToDecimal(DataTypes.createDecimalType(18, 0),
Expand Down Expand Up @@ -967,7 +1003,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
dataType: DataType,
scale: Int,
precision: Int = ai.rapids.cudf.DType.DECIMAL128_MAX_PRECISION,
floatEpsilon: Double = 1e-9,
floatEpsilon: Double = 1e-14,
customDataGenerator: Option[SparkSession => DataFrame] = None,
customRandGenerator: Option[scala.util.Random] = None,
ansiEnabled: Boolean = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,7 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
-9223183700000000000L
).toDF("longs")
}

def datesPostEpochDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RapidsTestSettings extends BackendTestSettings {
.exclude("SPARK-35719: cast timestamp with local time zone to timestamp without timezone", WONT_FIX_ISSUE("https://issues.apache.org/jira/browse/SPARK-40851"))
.exclude("SPARK-35112: Cast string to day-time interval", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10980"))
.exclude("SPARK-35735: Take into account day-time interval fields in cast", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10980"))
.exclude("casting to fixed-precision decimals", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10809"))
.exclude("casting to fixed-precision decimals", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11250"))
.exclude("SPARK-32828: cast from a derived user-defined type to a base type", WONT_FIX_ISSUE("User-defined types are not supported"))
.exclude("cast string to timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/blob/main/docs/compatibility.md#string-to-timestamp"))
.exclude("cast string to date", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))
Expand Down

0 comments on commit c9f1ab9

Please sign in to comment.