Implement kernel for casting float to decimal #2078

Merged: 36 commits from float_to_decimal into branch-24.08 on Jul 26, 2024
Changes from 33 commits

Commits (36)
2658d8c
WIP for float-to-decimal
ttnghia May 27, 2024
d082fbd
Cleanup
ttnghia May 27, 2024
4e474e4
Working prototype
ttnghia May 28, 2024
aab75cb
Fix corner cases
ttnghia May 29, 2024
9d70962
Fix precision
ttnghia May 29, 2024
20ca7d4
Remove redundant code
ttnghia May 30, 2024
e4a1d5d
Using float to decimal code from cudf
ttnghia Jun 6, 2024
e36ab57
Merge branch 'branch-24.08' into float_to_decimal
ttnghia Jul 15, 2024
d6b813a
Revert "Using float to decimal code from cudf"
ttnghia Jul 15, 2024
c14b199
Use new conversion code
ttnghia Jul 16, 2024
73aaffe
Revert "Use new conversion code"
ttnghia Jul 16, 2024
6f39e4c
Merge branch 'branch-24.08' into float_to_decimal
ttnghia Jul 19, 2024
41e035f
Update Java class
ttnghia Jul 19, 2024
bde2c56
Update JNI
ttnghia Jul 19, 2024
faca763
Update docs
ttnghia Jul 19, 2024
06a4d6e
Cleanup
ttnghia Jul 19, 2024
d69451f
Merge branch 'branch-24.08' into float_to_decimal
ttnghia Jul 19, 2024
f1667ab
Remove comment
ttnghia Jul 19, 2024
be84fb8
Change variable name
ttnghia Jul 19, 2024
f5c988e
Add Java test
ttnghia Jul 20, 2024
0d22af8
Change order of qualifier
ttnghia Jul 20, 2024
30b80e0
Add Java test
ttnghia Jul 20, 2024
90492aa
Revert unrelated changes
ttnghia Jul 20, 2024
06b6f1c
Change format
ttnghia Jul 20, 2024
0a4aa8e
Update copyright year
ttnghia Jul 20, 2024
90df388
Fix error message
ttnghia Jul 20, 2024
6b2ceae
Fix docs
ttnghia Jul 20, 2024
59f300c
Fix copyright year
ttnghia Jul 20, 2024
0888b1e
Fix docs
ttnghia Jul 20, 2024
ce3aa17
Probably final float-to-decimal
ttnghia Jul 23, 2024
86bcc87
Cleanup
ttnghia Jul 24, 2024
797ac1b
Cleanup
ttnghia Jul 24, 2024
a408c74
Merge branch 'branch-24.08' into float_to_decimal
ttnghia Jul 24, 2024
bac7919
Merge branch 'branch-24.08' into float_to_decimal
ttnghia Jul 25, 2024
f32465e
Fix typo
ttnghia Jul 25, 2024
d80b9f6
Revert changes in `DecimalUtilsTest.java`
ttnghia Jul 25, 2024
22 changes: 21 additions & 1 deletion src/main/cpp/src/DecimalUtilsJni.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -110,4 +110,24 @@ JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_subtr
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_floatingPointToDecimal(
JNIEnv* env, jclass, jlong j_input, jint output_type_id, jint precision, jint decimal_scale)
{
JNI_NULL_CHECK(env, j_input, "j_input is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const input = reinterpret_cast<cudf::column_view const*>(j_input);
cudf::jni::native_jlongArray output(env, 2);

auto [casted_col, has_failure] = cudf::jni::floating_point_to_decimal(
*input,
cudf::data_type{static_cast<cudf::type_id>(output_type_id), static_cast<int>(decimal_scale)},
precision);
output[0] = cudf::jni::release_as_jlong(std::move(casted_col));
output[1] = static_cast<jlong>(has_failure);
return output.get_jArray();
}
CATCH_STD(env, 0);
}

} // extern "C"
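For context, the two slots of the returned jlongArray carry the full result of the cast. The notes below on how the Java wrapper consumes them are an assumption (the Java class is updated elsewhere in this PR, not in this hunk):

// Layout of the jlongArray returned by floatingPointToDecimal (length 2), as packed above:
//   output[0] : native handle of the resulting decimal column, released to the caller
//               (presumably wrapped into a cudf Java ColumnVector and closed there)
//   output[1] : 0 or 1, whether any row failed the cast (non-finite input, or a value that
//               does not fit the requested precision), i.e. the bool returned by
//               cudf::jni::floating_point_to_decimal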
261 changes: 261 additions & 0 deletions src/main/cpp/src/decimal_utils.cu
@@ -15,14 +15,23 @@
*/

#include "decimal_utils.hpp"
#include "jni_utils.hpp"

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/valid_if.cuh>
#include <cudf/fixed_point/floating_conversion.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/functional.h>
#include <thrust/tabulate.h>

#include <cmath>
#include <cstddef>

@@ -1172,4 +1181,256 @@ std::unique_ptr<cudf::table> sub_decimal128(cudf::column_view const& a,
dec128_sub(overflows_view.begin<bool>(), sub_view, a, b));
return std::make_unique<cudf::table>(std::move(columns));
}

namespace {

using namespace numeric;
using namespace numeric::detail;

/**
* @brief Perform floating-point to integer decimal conversion, matching Spark behavior.
*
 * The input is scaled by 10^{-pow10} and rounded, so the desired decimal value is
 * (returned_value * 10^{pow10}).
*
* The rounding and precision decisions made here are chosen to match Apache Spark.
* Spark wants to perform the conversion as double to have the most precision.
* However, the behavior is still slightly different if the original type was float.
*
* @tparam FloatType The type of floating-point value we are converting from
* @tparam IntType The type of integer we are converting to, to store the decimal value
*
* @param input The floating point value to convert
 * @param pow10 The power of 10 of the target decimal scale; the input is scaled by 10^{-pow10}
 * @return Integer representation of the floating-point value, rounded after scaling
*/
template <typename FloatType,
typename IntType,
CUDF_ENABLE_IF(cuda::std::is_floating_point_v<FloatType>)>
__device__ inline IntType scaled_round(FloatType input, int32_t pow10)
{
// Extract components of the (double-ized) floating point number
using converter = floating_converter<double>;
auto const integer_rep = converter::bit_cast_to_integer(static_cast<double>(input));
if (converter::is_zero(integer_rep)) { return 0; }

// Note that the significand here is an unsigned integer of the same width as double (64 bits)
auto const is_negative = converter::get_is_negative(integer_rep);
auto const [significand, floating_pow2] = converter::get_significand_and_pow2(integer_rep);

auto const unsigned_floating = (input < 0) ? -input : input;
auto const rounding_wont_overflow = [&] {
auto const scale_factor = static_cast<double>(
multiply_power10<IntType>(cuda::std::make_unsigned_t<IntType>{1}, -pow10));
return 10.0 * static_cast<double>(unsigned_floating) * scale_factor <
static_cast<double>(cuda::std::numeric_limits<IntType>::max());
}();

// Spark often wants to round the last decimal place, so we'll perform the conversion
// with one lower power of 10 so that we can (optionally) round at the end.
// Note that we can't round this way if we've requested the minimum power.
bool const can_round = cuda::std::is_same_v<IntType, __int128_t> ? rounding_wont_overflow : true;
auto const shifting_pow10 = can_round ? pow10 - 1 : pow10;

// Sometimes add half a bit to correct for compiler rounding to nearest floating-point value.
// See comments in add_half_if_truncates(), with differences detailed below.
// Even if we don't add the bit, shift bits to line up with what the shifting algorithm is
// expecting.
bool const is_whole_number = cuda::std::floor(input) == input;
auto const [base2_value, pow2] = [is_whole_number](auto significand, auto floating_pow2) {
if constexpr (cuda::std::is_same_v<FloatType, double>) {
// Add the 1/2 bit regardless of truncation, but still not for whole numbers.
auto const base2_value =
(significand << 1) + static_cast<decltype(significand)>(!is_whole_number);
return cuda::std::make_pair(base2_value, floating_pow2 - 1);
} else {
// Input was float: never add 1/2 bit.
// Why? Because we converted to double, and the 1/2 bit beyond float is WAY too large compared
// to double's precision. And the 1/2 bit beyond double is not due to user input.
return cuda::std::make_pair(significand << 1, floating_pow2 - 1);
}
}(significand, floating_pow2);

// Main algorithm: Apply the powers of 2 and 10 (except for the last power-of-10).
// Use larger intermediate type for conversion to avoid overflow for last power-of-10.
using intermediate_type =
cuda::std::conditional_t<cuda::std::is_same_v<IntType, std::int32_t>, std::int64_t, __int128_t>;
cuda::std::make_unsigned_t<intermediate_type> magnitude =
[&, base2_value = base2_value, pow2 = pow2] {
if constexpr (cuda::std::is_same_v<IntType, std::int32_t>) {
return rounding_wont_overflow ? convert_floating_to_integral_shifting<IntType, double>(
base2_value, shifting_pow10, pow2)
: convert_floating_to_integral_shifting<std::int64_t, double>(
base2_value, shifting_pow10, pow2);
} else {
return convert_floating_to_integral_shifting<__int128_t, double>(
base2_value, shifting_pow10, pow2);
}
}();

// Spark wants to floor the last digits of the output, clearing data that was beyond the
// precision that was available in double.

// How many digits do we need to floor?
// From the decimal digit corresponding to pow2 (just past double precision) to the end (pow10).
int const floor_pow10 = [&](int pow2_bit) {
// The conversion from pow2 to pow10 is log10(2), which is ~ 90/299 (close enough for ints)
// But Spark chooses the rougher 3/10 ratio instead of 90/299.
if constexpr (cuda::std::is_same_v<FloatType, float>) {
return (3 * pow2_bit - 10 * pow10) / 10;
} else {
// Spark rounds up the power-of-10 to floor for DOUBLES >= 2^63 (and yes, this is the exact
// cutoff).
bool const round_up = unsigned_floating > std::numeric_limits<std::int64_t>::max();
return (3 * pow2_bit - 10 * pow10 + 9 * round_up) / 10;
}
}(pow2);

// Floor end digits
if (can_round) {
if (floor_pow10 < 0) {
// Truncated: The scale factor cut off the extra, imprecise bits.
// To round to the final decimal place, add 5 to one past the last decimal place.
magnitude += 5U;
magnitude /= 10U; // Apply the last power of 10
} else {
// We are keeping decimal digits with data beyond the precision of double.
// We want to truncate these digits, but sometimes we want to round first.
// We will round if and only if we didn't already add a half-bit earlier.
if constexpr (cuda::std::is_same_v<FloatType, double>) {
// For doubles, only round the extra digits of whole numbers.
// If it was not a whole number, we already added 1/2 a bit at higher precision than this
// earlier.
if (is_whole_number) {
magnitude += multiply_power10<IntType>(decltype(magnitude)(5), floor_pow10);
}
} else {
// Input was float: we didn't add a half-bit earlier, so round at the edge of precision
// here.
magnitude += multiply_power10<IntType>(decltype(magnitude)(5), floor_pow10);
}

// +1: Divide the last power-of-10 that we postponed earlier to do rounding.
auto const truncated = divide_power10<IntType>(magnitude, floor_pow10 + 1);
magnitude = multiply_power10<IntType>(truncated, floor_pow10);
}
} else if (floor_pow10 > 0) {
auto const truncated = divide_power10<IntType>(magnitude, floor_pow10);
magnitude = multiply_power10<IntType>(truncated, floor_pow10);
}

// Reapply the sign and return.
// NOTE: Cast can overflow!
auto const signed_magnitude = static_cast<IntType>(magnitude);
return is_negative ? -signed_magnitude : signed_magnitude;
}
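To make the sign convention concrete, here is a naive host-side approximation of what scaled_round computes. naive_scaled_round is a hypothetical illustration, not code from this PR, and it ignores the overflow checks and the Spark-specific rounding/flooring handled above:

#include <cmath>
#include <cstdint>

// Naive reference: scale the input by 10^{-pow10} and round to nearest.
// Example: naive_scaled_round(123.456, -2) == 12346, i.e. 123.46 at decimal scale -2.
inline std::int64_t naive_scaled_round(double input, std::int32_t pow10)
{
  return static_cast<std::int64_t>(std::llround(input * std::pow(10.0, -pow10)));
}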

template <typename FloatType, typename DecimalRepType>
struct floating_point_to_decimal_fn {
cudf::column_device_view input;
int8_t* validity;
bool* has_failure;
int32_t decimal_places;
DecimalRepType exclusive_bound;

__device__ DecimalRepType operator()(cudf::size_type idx) const
{
auto const x = input.element<FloatType>(idx);

if (input.is_null(idx) || !std::isfinite(x)) {
if (!std::isfinite(x)) { *has_failure = true; }
validity[idx] = false;
return DecimalRepType{0};
}

auto const scaled_rounded = scaled_round<FloatType, DecimalRepType>(x, -decimal_places);
auto const is_out_of_bound =
-exclusive_bound >= scaled_rounded || scaled_rounded >= exclusive_bound;
if (is_out_of_bound) { *has_failure = true; }
validity[idx] = !is_out_of_bound;

return is_out_of_bound ? DecimalRepType{0} : scaled_rounded;
}
};
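A worked example of the bound check in this functor, with illustrative values (the precision and inputs are assumptions, not taken from the PR):

// precision = 9, decimal_places = 2  =>  exclusive_bound = 10^9
//   x = 1234567.89  -> scaled_round(x, -2) =  123456789 <  10^9 -> stored, row stays valid
//   x = 12345678.9  -> scaled_round(x, -2) = 1234567890 >= 10^9 -> row is null, *has_failure = true
//   x = NaN or +/-Inf                                           -> row is null, *has_failure = true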

struct floating_point_to_decimal_dispatcher {
template <typename FloatType, typename DecimalType>
static constexpr bool supported_types()
{
return (std::is_same_v<FloatType, float> || //
std::is_same_v<FloatType, double>)&& //
(std::is_same_v<DecimalType, numeric::decimal32> ||
std::is_same_v<DecimalType, numeric::decimal64> ||
std::is_same_v<DecimalType, numeric::decimal128>);
}

template <typename FloatType,
typename DecimalType,
typename... Args,
CUDF_ENABLE_IF(not supported_types<FloatType, DecimalType>())>
void operator()(Args...) const
{
CUDF_FAIL("Unsupported types for floating_point_to_decimal_fn", cudf::data_type_error);
}

template <typename FloatType,
typename DecimalType,
CUDF_ENABLE_IF(supported_types<FloatType, DecimalType>())>
void operator()(cudf::column_view const& input,
cudf::mutable_column_view const& output,
int8_t* validity,
bool* has_failure,
int32_t decimal_places,
int32_t precision,
rmm::cuda_stream_view stream) const
{
using DecimalRepType = cudf::device_storage_type_t<DecimalType>;

auto const d_input_ptr = cudf::column_device_view::create(input, stream);
auto const exclusive_bound = static_cast<DecimalRepType>(
multiply_power10<DecimalRepType>(cuda::std::make_unsigned_t<DecimalRepType>{1}, precision));

thrust::tabulate(rmm::exec_policy_nosync(stream),
output.begin<DecimalRepType>(),
output.end<DecimalRepType>(),
floating_point_to_decimal_fn<FloatType, DecimalRepType>{
*d_input_ptr, validity, has_failure, decimal_places, exclusive_bound});
}
};

} // namespace

std::pair<std::unique_ptr<cudf::column>, bool> floating_point_to_decimal(
cudf::column_view const& input,
cudf::data_type output_type,
int32_t precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto output = cudf::make_fixed_point_column(
output_type, input.size(), cudf::mask_state::UNALLOCATED, stream, mr);

auto const decimal_places = -output_type.scale();
auto const default_mr = rmm::mr::get_current_device_resource();

rmm::device_uvector<int8_t> validity(input.size(), stream, default_mr);
rmm::device_scalar<bool> has_failure(false, stream, default_mr);

cudf::double_type_dispatcher(input.type(),
output_type,
floating_point_to_decimal_dispatcher{},
input,
output->mutable_view(),
validity.begin(),
has_failure.data(),
decimal_places,
precision,
stream);

auto [null_mask, null_count] =
cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity{}, stream, mr);
if (null_count > 0) { output->set_null_mask(std::move(null_mask), null_count); }

return {std::move(output), has_failure.value(stream)};
}

} // namespace cudf::jni
21 changes: 20 additions & 1 deletion src/main/cpp/src/decimal_utils.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -62,4 +62,23 @@ std::unique_ptr<cudf::table> sub_decimal128(
cudf::column_view const& b,
int32_t quotient_scale,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Cast floating point values to decimals, matching the behavior of Spark.
*
* @param input The input column, which is either FLOAT32 or FLOAT64 type
* @param output_type The output decimal type
* @param precision The maximum number of digits that will be preserved in the output
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
 * @return A cudf column containing the cast result, and a boolean value indicating whether the
 *         cast operation failed for any input row
*/
std::pair<std::unique_ptr<cudf::column>, bool> floating_point_to_decimal(
cudf::column_view const& input,
cudf::data_type output_type,
int32_t precision,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

} // namespace cudf::jni
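A minimal usage sketch of this entry point, assuming an existing cudf::column_view named input of type FLOAT64; only the call shape and the return contract come from the declaration and implementation above:

// Cast to DECIMAL(9, 2): decimal64 storage, scale -2 (two fractional digits), precision 9.
auto const output_type = cudf::data_type{cudf::type_id::DECIMAL64, -2};
auto [decimal_col, had_failure] =
  cudf::jni::floating_point_to_decimal(input, output_type, /*precision=*/9);

// Rows that were non-finite or that did not fit within 9 digits are null in decimal_col;
// had_failure reports whether any such row existed (useful for Spark ANSI-mode error handling).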