diff --git a/src/main/cpp/src/DecimalUtilsJni.cpp b/src/main/cpp/src/DecimalUtilsJni.cpp
index f732276817..6c7c1cc781 100644
--- a/src/main/cpp/src/DecimalUtilsJni.cpp
+++ b/src/main/cpp/src/DecimalUtilsJni.cpp
@@ -19,8 +19,13 @@
 
 extern "C" {
 
-JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_multiply128(
-  JNIEnv* env, jclass, jlong j_view_a, jlong j_view_b, jint j_product_scale)
+JNIEXPORT jlongArray JNICALL
+Java_com_nvidia_spark_rapids_jni_DecimalUtils_multiply128(JNIEnv* env,
+                                                          jclass,
+                                                          jlong j_view_a,
+                                                          jlong j_view_b,
+                                                          jint j_product_scale,
+                                                          jboolean cast_interim_result)
 {
   JNI_NULL_CHECK(env, j_view_a, "column is null", 0);
   JNI_NULL_CHECK(env, j_view_b, "column is null", 0);
@@ -30,7 +35,7 @@ JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_multi
     auto view_b = reinterpret_cast<cudf::column_view const*>(j_view_b);
     auto scale  = static_cast<int>(j_product_scale);
     return cudf::jni::convert_table_for_return(
-      env, cudf::jni::multiply_decimal128(*view_a, *view_b, scale));
+      env, cudf::jni::multiply_decimal128(*view_a, *view_b, scale, cast_interim_result));
   }
   CATCH_STD(env, 0);
 }
diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu
index 392fb495b4..92273ff545 100644
--- a/src/main/cpp/src/decimal_utils.cu
+++ b/src/main/cpp/src/decimal_utils.cu
@@ -657,14 +657,16 @@ struct dec128_multiplier {
   dec128_multiplier(bool* overflows,
                     cudf::mutable_column_view const& product_view,
                     cudf::column_view const& a_col,
-                    cudf::column_view const& b_col)
+                    cudf::column_view const& b_col,
+                    bool const cast_interim_result)
     : overflows(overflows),
       a_data(a_col.data<__int128_t>()),
       b_data(b_col.data<__int128_t>()),
       product_data(product_view.data<__int128_t>()),
       a_scale(a_col.type().scale()),
       b_scale(b_col.type().scale()),
-      prod_scale(product_view.type().scale())
+      prod_scale(product_view.type().scale()),
+      cast_interim_result(cast_interim_result)
   {
   }
 
@@ -675,22 +677,24 @@
 
     chunked256 product = multiply(a, b);
 
-    // Spark does some really odd things that I personally think are a bug
-    // https://issues.apache.org/jira/browse/SPARK-40129
-    // But to match Spark we need to first round the result to a precision of 38
-    // and this is specific to the value in the result of the multiply.
-    // Then we need to round the result to the final scale that we care about.
-    int dec_precision       = precision10(product);
-    int first_div_precision = dec_precision - 38;
-
-    int mult_scale = a_scale + b_scale;
-    if (first_div_precision > 0) {
-      auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits();
-      product                            = divide_and_round(product, first_div_scale_divisor);
-
-      // a_scale and b_scale are negative. first_div_precision is not
-      mult_scale = a_scale + b_scale + first_div_precision;
-    }
+    int const mult_scale = [&]() {
+      // According to https://issues.apache.org/jira/browse/SPARK-40129 and
+      // https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in versions
+      // 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0. The bug is fixed in later versions, but to
+      // match the legacy behavior we need to first round the result to a precision of 38
+      // and then round it to the final scale that we care about.
+      if (cast_interim_result) {
+        auto const first_div_precision = precision10(product) - 38;
+        if (first_div_precision > 0) {
+          auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits();
+          product = divide_and_round(product, first_div_scale_divisor);
+
+          // a_scale and b_scale are negative; first_div_precision is not.
+          return a_scale + b_scale + first_div_precision;
+        }
+      }
+      return a_scale + b_scale;
+    }();
 
     int exponent = prod_scale - mult_scale;
     if (exponent < 0) {
@@ -718,6 +722,8 @@ struct dec128_multiplier {
  private:
   // output column for overflow detected
   bool* const overflows;
+  // whether to round the interim product to a precision of 38 to match legacy Spark
+  bool const cast_interim_result;
 
   // input data for multiply
   __int128_t const* const a_data;
@@ -968,6 +973,7 @@ namespace cudf::jni {
 std::unique_ptr<cudf::table> multiply_decimal128(cudf::column_view const& a,
                                                  cudf::column_view const& b,
                                                  int32_t product_scale,
+                                                 bool const cast_interim_result,
                                                  rmm::cuda_stream_view stream)
 {
   CUDF_EXPECTS(a.type().id() == cudf::type_id::DECIMAL128, "not a DECIMAL128 column");
@@ -992,10 +998,11 @@ std::unique_ptr<cudf::table> multiply_decimal128(cudf::column_view const& a,
   auto overflows_view = columns[0]->mutable_view();
   auto product_view   = columns[1]->mutable_view();
   check_scale_divisor(a.type().scale() + b.type().scale(), product_scale);
-  thrust::for_each(rmm::exec_policy(stream),
-                   thrust::make_counting_iterator(0),
-                   thrust::make_counting_iterator(num_rows),
-                   dec128_multiplier(overflows_view.begin<bool>(), product_view, a, b));
+  thrust::for_each(
+    rmm::exec_policy(stream),
+    thrust::make_counting_iterator(0),
+    thrust::make_counting_iterator(num_rows),
+    dec128_multiplier(overflows_view.begin<bool>(), product_view, a, b, cast_interim_result));
 
   return std::make_unique<cudf::table>(std::move(columns));
 }
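The effect of `cast_interim_result` is easiest to see by replaying both code paths above on the CPU. The sketch below is illustrative only: it stands in `java.math.BigDecimal` for the kernel's 256-bit `chunked256` arithmetic, assumes `divide_and_round` rounds half up like `RoundingMode.HALF_UP`, and the class name is invented for the example.

```java
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;

// Hypothetical CPU replay of the two dec128_multiplier paths for one value pair.
public class InterimCastSketch {
  public static void main(String[] args) {
    BigDecimal a = new BigDecimal("-8533444864753048107770677711.1312637916");
    BigDecimal b = new BigDecimal("-12.0000000000");
    BigDecimal product = a.multiply(b); // exact product, scale 20, 50 digits unscaled

    // cast_interim_result == false: round once to the target scale of 6.
    BigDecimal exact = product.setScale(6, RoundingMode.HALF_UP);

    // cast_interim_result == true: first round the interim product to a precision
    // of 38 (the divide_and_round by pow_ten(first_div_precision) above), then
    // round to the target scale of 6.
    BigDecimal legacy = product.round(new MathContext(38, RoundingMode.HALF_UP))
                               .setScale(6, RoundingMode.HALF_UP);

    System.out.println(exact);  // 102401338377036577293248132533.575165
    System.out.println(legacy); // 102401338377036577293248132533.575166
  }
}
```

The two answers differ because the exact product is 102401338377036577293248132533.5751654992: rounding straight to scale 6 discards 0.0000004992, while rounding first to 38 significant digits gives ...57516550, whose trailing exact half then rounds up to ...575166.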
diff --git a/src/main/cpp/src/decimal_utils.hpp b/src/main/cpp/src/decimal_utils.hpp
index 95c6c56c3d..9793e63445 100644
--- a/src/main/cpp/src/decimal_utils.hpp
+++ b/src/main/cpp/src/decimal_utils.hpp
@@ -30,6 +30,7 @@ std::unique_ptr<cudf::table> multiply_decimal128(
   cudf::column_view const& a,
   cudf::column_view const& b,
   int32_t product_scale,
+  bool const cast_interim_result,
   rmm::cuda_stream_view stream = cudf::get_default_stream());
 
 std::unique_ptr<cudf::table> divide_decimal128(
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
index 389679965a..17337691c5 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
@@ -25,21 +25,50 @@ public class DecimalUtils {
     NativeDepsLoader.loadNativeDeps();
   }
 
+  /**
+   * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified
+   * scale with overflow detection. This method considers a precision greater than 38 as overflow
+   * even if the number still fits in a 128-bit representation.
+   *
+   * WARNING: This method reproduces a rounding bug that we match from Spark versions before
+   * 3.4.2, 3.5.1 and 4.0.0. Consider the following example using decimals with a precision of 38
+   * and a scale of 10:
+   * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
+   * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165.
+   *
+   * @param a factor input, must match row count of the other factor input
+   * @param b factor input, must match row count of the other factor input
+   * @param productScale scale to use for the product type
+   * @return table containing a boolean column and a DECIMAL128 product column of the specified
+   *         scale. The boolean value will be true if an overflow was detected for that row's
+   *         DECIMAL128 product value. A null input row will result in a corresponding null output
+   *         row.
+   */
+  public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
+    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, true));
+  }
+
   /**
    * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified
    * scale with overflow detection. This method considers a precision greater than 38 as overflow
    * even if the number still fits in a 128-bit representation.
+   *
+   * WARNING: With interimCast set to true, this method reproduces a rounding bug that we match
+   * from Spark versions before 3.4.2, 3.5.1 and 4.0.0. Consider the following example using
+   * decimals with a precision of 38 and a scale of 10:
+   * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
+   * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165.
+   *
    * @param a factor input, must match row count of the other factor input
    * @param b factor input, must match row count of the other factor input
    * @param productScale scale to use for the product type
+   * @param interimCast whether to round the interim product to a precision of 38 before rounding
+   *                    again to the final scale
    * @return table containing a boolean column and a DECIMAL128 product column of the specified
    *         scale. The boolean value will be true if an overflow was detected for that row's
    *         DECIMAL128 product value. A null input row will result in a corresponding null output
    *         row.
    */
-  public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
-    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale));
+  public static Table multiply128(ColumnView a, ColumnView b, int productScale, boolean interimCast) {
+    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, interimCast));
   }
 
   /**
@@ -148,7 +177,7 @@ public static Table add128(ColumnView a, ColumnView b, int targetScale) {
     return new Table(add128(a.getNativeView(), b.getNativeView(), targetScale));
   }
 
-  private static native long[] multiply128(long viewA, long viewB, int productScale);
+  private static native long[] multiply128(long viewA, long viewB, int productScale, boolean interimCast);
 
   private static native long[] divide128(long viewA, long viewB, int quotientScale, boolean isIntegerDivide);
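On the Java side the new boolean simply selects between those two behaviors, and the three-argument overload delegates with `true` so existing callers keep their current results. The following driver is a hypothetical usage sketch, not code from this change: it assumes cudf's `ColumnVector.decimalFromBigInt` factory for DECIMAL128 columns (cudf scales are negative, so -10 means ten fractional digits) and reuses the values from the Javadoc example.

```java
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.DecimalUtils;
import java.math.BigInteger;

public class Multiply128Demo {
  public static void main(String[] args) {
    // Unscaled DECIMAL128 values at cudf scale -10:
    // -8533444864753048107770677711.1312637916 and -12.0000000000
    try (ColumnVector lhs = ColumnVector.decimalFromBigInt(-10,
             new BigInteger("-85334448647530481077706777111312637916"));
         ColumnVector rhs = ColumnVector.decimalFromBigInt(-10,
             new BigInteger("-120000000000"));
         // interimCast=true keeps the legacy Spark rounding
         Table legacy = DecimalUtils.multiply128(lhs, rhs, -6, true);
         Table exact = DecimalUtils.multiply128(lhs, rhs, -6, false);
         HostColumnVector legacyProduct = legacy.getColumn(1).copyToHost();
         HostColumnVector exactProduct = exact.getColumn(1).copyToHost()) {
      // Column 0 holds the per-row overflow flags, column 1 the product.
      System.out.println(legacyProduct.getBigDecimal(0)); // ...575166
      System.out.println(exactProduct.getBigDecimal(0));  // ...575165
    }
  }
}
```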
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
index 4698855f31..7f3079e825 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
@@ -86,6 +86,18 @@ void simplePosMultiplyZeroByNegOne() {
     }
   }
 
+  @Test
+  void multiply128WithoutInterimCast() {
+    try (ColumnVector lhs = makeDec128Column("-8533444864753048107770677711.1312637916");
+         ColumnVector rhs = makeDec128Column("-12.0000000000");
+         ColumnVector expectedBasic = makeDec128Column("102401338377036577293248132533.575165");
+         ColumnVector expectedValid = ColumnVector.fromBooleans(false);
+         Table found = DecimalUtils.multiply128(lhs, rhs, -6, false)) {
+      assertColumnsAreEqual(expectedValid, found.getColumn(0));
+      assertColumnsAreEqual(expectedBasic, found.getColumn(1));
+    }
+  }
+
   @Test
   void largePosMultiplyTenByTen() {
     try (ColumnVector lhs =