Skip to content

Commit

Permalink
fix c++ tests
Browse files Browse the repository at this point in the history
  • Loading branch information
khwilson committed Oct 13, 2024
1 parent 12a2749 commit 01e53b3
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 74 deletions.
22 changes: 19 additions & 3 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,10 @@ struct ProductImpl : public ScalarAggregator {
}

Status Finalize(KernelContext*, Datum* out) override {
std::shared_ptr<DataType> out_type_;
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));
std::shared_ptr<DataType> out_type_ = this->out_type;
if (is_decimal(this->out_type->id())) {
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));
}

if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
Expand Down Expand Up @@ -1051,6 +1053,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), sum_doc,
&default_scalar_aggregate_options);
AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL32}, MaxPrecisionDecimalType),
SumInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL64}, MaxPrecisionDecimalType),
SumInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType),
SumInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType),
Expand Down Expand Up @@ -1079,6 +1085,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
&default_scalar_aggregate_options);
AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get());
AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL32}, MaxPrecisionDecimalType),
MeanInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL64}, MaxPrecisionDecimalType),
MeanInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType),
MeanInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType),
Expand Down Expand Up @@ -1128,6 +1138,8 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddMinMaxKernels(MinMaxInitDefault, BaseBinaryTypes(), func.get());
AddMinMaxKernel(MinMaxInitDefault, Type::FIXED_SIZE_BINARY, func.get());
AddMinMaxKernel(MinMaxInitDefault, Type::INTERVAL_MONTHS, func.get());
AddMinMaxKernel(MinMaxInitDefault, Type::DECIMAL32, func.get());
AddMinMaxKernel(MinMaxInitDefault, Type::DECIMAL64, func.get());
AddMinMaxKernel(MinMaxInitDefault, Type::DECIMAL128, func.get());
AddMinMaxKernel(MinMaxInitDefault, Type::DECIMAL256, func.get());
// Add the SIMD variants for min max
Expand Down Expand Up @@ -1163,6 +1175,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get());
AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(),
func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL32}, MaxPrecisionDecimalType),
ProductInit::Init, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL64}, MaxPrecisionDecimalType),
ProductInit::Init, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType),
ProductInit::Init, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType),
Expand All @@ -1188,7 +1204,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddBasicAggKernels(IndexInit::Init, PrimitiveTypes(), int64(), func.get());
AddBasicAggKernels(IndexInit::Init, TemporalTypes(), int64(), func.get());
AddBasicAggKernels(IndexInit::Init,
{fixed_size_binary(1), decimal128(1, 0), decimal256(1, 0), null()},
{fixed_size_binary(1), decimal32(1, 0), decimal64(1, 0), decimal128(1, 0), decimal256(1, 0), null()},
int64(), func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));
}
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ struct SumImpl : public ScalarAggregator {
}

Status Finalize(KernelContext*, Datum* out) override {
std::shared_ptr<DataType> out_type_;
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));
std::shared_ptr<DataType> out_type_ = this->out_type;
if (is_decimal(this->out_type->id())) {
ARROW_ASSIGN_OR_RAISE(out_type_, WidenDecimalToMaxPrecision(this->out_type));
}

if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
Expand Down
Loading

0 comments on commit 01e53b3

Please sign in to comment.