Skip to content

Commit

Permalink
apacheGH-43956: [C++][Compute] Add Decimal32/64 Casts
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade committed Dec 12, 2024
1 parent 5a042fc commit a31f916
Show file tree
Hide file tree
Showing 4 changed files with 1,387 additions and 31 deletions.
177 changes: 175 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,43 @@ struct DecimalConversions<Decimal256, InDecimal> {
static Decimal256 ConvertOutput(Decimal256&& val) { return val; }
};

template <typename InDecimal>
struct DecimalConversions<Decimal32, InDecimal> {
static Decimal32 ConvertInput(InDecimal&& val) { return Decimal32(val.low_bits()); }
static Decimal32 ConvertOutput(Decimal32&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal64, Decimal32> {
// Convert then scale
static Decimal64 ConvertInput(Decimal32&& val) { return Decimal64(val); }
static Decimal64 ConvertOutput(Decimal64&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal64, Decimal64> {
static Decimal64 ConvertInput(Decimal64&& val) { return val; }
static Decimal64 ConvertOutput(Decimal64&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal64, Decimal128> {
// Scale then truncate
static Decimal128 ConvertInput(Decimal128&& val) { return val; }
static Decimal64 ConvertOutput(Decimal128&& val) {
return Decimal64(static_cast<int64_t>(val.low_bits()));
}
};

template <>
struct DecimalConversions<Decimal64, Decimal256> {
// Scale then truncate
static Decimal256 ConvertInput(Decimal256&& val) { return val; }
static Decimal64 ConvertOutput(Decimal256&& val) {
return Decimal64(static_cast<int64_t>(val.low_bits()));
}
};

template <>
struct DecimalConversions<Decimal128, Decimal256> {
// Scale then truncate
Expand All @@ -495,6 +532,20 @@ struct DecimalConversions<Decimal128, Decimal128> {
static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal128, Decimal64> {
// convert then scale
static Decimal128 ConvertInput(Decimal64&& val) { return Decimal128(val.value()); }
static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
};

template <>
struct DecimalConversions<Decimal128, Decimal32> {
// convert then scale
static Decimal128 ConvertInput(Decimal32&& val) { return Decimal128(val.value()); }
static Decimal128 ConvertOutput(Decimal128&& val) { return val; }
};

struct UnsafeUpscaleDecimal {
template <typename OutValue, typename Arg0Value>
OutValue Call(KernelContext*, Arg0Value val, Status*) const {
Expand Down Expand Up @@ -659,6 +710,18 @@ struct DecimalCastFunctor {
}
};

template <typename I>
struct CastFunctor<
Decimal32Type, I,
enable_if_t<is_base_binary_type<I>::value || is_binary_view_like_type<I>::value>>
: public DecimalCastFunctor<Decimal32Type, I> {};

template <typename I>
struct CastFunctor<
Decimal64Type, I,
enable_if_t<is_base_binary_type<I>::value || is_binary_view_like_type<I>::value>>
: public DecimalCastFunctor<Decimal64Type, I> {};

template <typename I>
struct CastFunctor<
Decimal128Type, I,
Expand Down Expand Up @@ -744,6 +807,10 @@ std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
// From decimal to integer
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, out_ty,
CastFunctor<OutType, Decimal32Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, out_ty,
CastFunctor<OutType, Decimal64Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
CastFunctor<OutType, Decimal256Type>::Exec));
return func;
Expand Down Expand Up @@ -772,6 +839,10 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
AddCommonNumberCasts<OutType>(out_ty, func.get());

// From decimal to floating point
DCHECK_OK(func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, out_ty,
CastFunctor<OutType, Decimal32Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, out_ty,
CastFunctor<OutType, Decimal64Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType(Type::DECIMAL)}, out_ty,
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
Expand All @@ -780,6 +851,94 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
return func;
}

std::shared_ptr<CastFunction> GetCastToDecimal32() {
OutputType sig_out_ty(ResolveOutputFromOptions);

auto func = std::make_shared<CastFunction>("cast_decimal32", Type::DECIMAL32);
AddCommonCasts(Type::DECIMAL32, sig_out_ty, func.get());

// Cast from floating point
DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
CastFunctor<Decimal32Type, FloatType>::Exec));
DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
CastFunctor<Decimal32Type, DoubleType>::Exec));

// Cast from integer
for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
auto exec = GenerateInteger<CastFunctor, Decimal32Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other strings
for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
auto exec = GenerateVarBinaryBase<CastFunctor, Decimal32Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}
for (const std::shared_ptr<DataType>& in_ty : BinaryViewTypes()) {
auto exec = GenerateVarBinaryViewBase<CastFunctor, Decimal32Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other decimal
auto exec = CastFunctor<Decimal32Type, Decimal32Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal32Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal32Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal32Type, Decimal256Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
return func;
}

std::shared_ptr<CastFunction> GetCastToDecimal64() {
OutputType sig_out_ty(ResolveOutputFromOptions);

auto func = std::make_shared<CastFunction>("cast_decimal64", Type::DECIMAL64);
AddCommonCasts(Type::DECIMAL64, sig_out_ty, func.get());

// Cast from floating point
DCHECK_OK(func->AddKernel(Type::FLOAT, {float32()}, sig_out_ty,
CastFunctor<Decimal64Type, FloatType>::Exec));
DCHECK_OK(func->AddKernel(Type::DOUBLE, {float64()}, sig_out_ty,
CastFunctor<Decimal64Type, DoubleType>::Exec));

// Cast from integer
for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
auto exec = GenerateInteger<CastFunctor, Decimal64Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other strings
for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
auto exec = GenerateVarBinaryBase<CastFunctor, Decimal64Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}
for (const std::shared_ptr<DataType>& in_ty : BinaryViewTypes()) {
auto exec = GenerateVarBinaryViewBase<CastFunctor, Decimal64Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
}

// Cast from other decimal
auto exec = CastFunctor<Decimal64Type, Decimal32Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal64Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal64Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal64Type, Decimal256Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, sig_out_ty, exec));
return func;
}

std::shared_ptr<CastFunction> GetCastToDecimal128() {
OutputType sig_out_ty(ResolveOutputFromOptions);

Expand Down Expand Up @@ -809,8 +968,14 @@ std::shared_ptr<CastFunction> GetCastToDecimal128() {
}

// Cast from other decimal
auto exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
auto exec = CastFunctor<Decimal128Type, Decimal32Type>::Exec;
// We resolve the output type of this kernel from the CastOptions
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal128Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal128Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal128Type, Decimal256Type>::Exec;
Expand Down Expand Up @@ -848,7 +1013,13 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() {
}

// Cast from other decimal
auto exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
auto exec = CastFunctor<Decimal256Type, Decimal32Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL32, {InputType(Type::DECIMAL32)}, sig_out_ty, exec));
exec = CastFunctor<Decimal256Type, Decimal64Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL64, {InputType(Type::DECIMAL64)}, sig_out_ty, exec));
exec = CastFunctor<Decimal256Type, Decimal128Type>::Exec;
DCHECK_OK(
func->AddKernel(Type::DECIMAL128, {InputType(Type::DECIMAL128)}, sig_out_ty, exec));
exec = CastFunctor<Decimal256Type, Decimal256Type>::Exec;
Expand Down Expand Up @@ -950,6 +1121,8 @@ std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
auto cast_double = GetCastToFloating<DoubleType>("cast_double");
functions.push_back(cast_double);

functions.push_back(GetCastToDecimal32());
functions.push_back(GetCastToDecimal64());
functions.push_back(GetCastToDecimal128());
functions.push_back(GetCastToDecimal256());

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_cast_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ void AddNumberToStringCasts(CastFunction* func) {
template <typename OutType>
void AddDecimalToStringCasts(CastFunction* func) {
auto out_ty = TypeTraits<OutType>::type_singleton();
for (const auto& in_tid : std::vector<Type::type>{Type::DECIMAL128, Type::DECIMAL256}) {
for (const auto& in_tid : std::vector<Type::type>{Type::DECIMAL32, Type::DECIMAL64,
Type::DECIMAL128, Type::DECIMAL256}) {
DCHECK_OK(
func->AddKernel(in_tid, {in_tid}, out_ty,
GenerateDecimal<DecimalToStringCastFunctor, OutType>(in_tid),
Expand Down
Loading

0 comments on commit a31f916

Please sign in to comment.