Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
khwilson committed Sep 23, 2024
1 parent 0078523 commit 57e655c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
36 changes: 20 additions & 16 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,14 @@ struct ProductImpl : public ScalarAggregator {

Status Finalize(KernelContext*, Datum* out) override {
std::shared_ptr<DataType> out_type_;
if (auto decimal128_type = std::dynamic_pointer_cast<Decimal128Type>(this->out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, decimal128_type->scale()));
} else if (auto decimal256_type = std::dynamic_pointer_cast<Decimal256Type>(this->out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, decimal256_type->scale()));
if (out_type->id() == Type::DECIMAL128) {
auto cast_type = checked_cast<Decimal128Type>(out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
cast_type->scale()));
} else if (out_type->id() == Type::DECIMAL256) {
auto cast_type = checked_cast<Decimal256Type>(out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
cast_type->scale()));
} else {
out_type_ = out_type;
}
Expand Down Expand Up @@ -1057,10 +1061,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), sum_doc,
&default_scalar_aggregate_options);
AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType), SumInit, func.get(),
SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType), SumInit, func.get(),
SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType),
SumInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType),
SumInit, func.get(), SimdLevel::NONE);
AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get());
AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get());
AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get());
Expand All @@ -1085,10 +1089,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
&default_scalar_aggregate_options);
AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get());
AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType), MeanInit, func.get(),
SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType), MeanInit, func.get(),
SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType),
MeanInit, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType),
MeanInit, func.get(), SimdLevel::NONE);
AddArrayScalarAggKernels(MeanInit, {null()}, float64(), func.get());
// Add the SIMD variants for mean
#if defined(ARROW_HAVE_RUNTIME_AVX2)
Expand Down Expand Up @@ -1169,10 +1173,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get());
AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(),
func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType), ProductInit::Init,
func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType), ProductInit::Init,
func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType),
ProductInit::Init, func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType),
ProductInit::Init, func.get(), SimdLevel::NONE);
AddArrayScalarAggKernels(ProductInit::Init, {null()}, int64(), func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));

Expand Down
12 changes: 8 additions & 4 deletions cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ struct SumImpl : public ScalarAggregator {
std::shared_ptr<DataType> out_type_;
if (out_type->id() == Type::DECIMAL128) {
auto cast_type = checked_cast<Decimal128Type>(out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, cast_type->scale()));
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
cast_type->scale()));
} else if (out_type->id() == Type::DECIMAL256) {
auto cast_type = checked_cast<Decimal256Type>(out_type);
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, cast_type->scale()));
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
cast_type->scale()));
} else {
out_type_ = out_type;
}
Expand Down Expand Up @@ -233,9 +235,11 @@ struct MeanImpl<ArrowType, SimdLevel, enable_if_decimal<ArrowType>>
std::shared_ptr<DataType> out_type_;
auto decimal_type = checked_cast<DecimalType>(this->out_type);
if (decimal_type->id() == Type::DECIMAL128) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, decimal_type->scale()));
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
decimal_type->scale()));
} else {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, decimal_type->scale()));
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
decimal_type->scale()));
}

if ((!options.skip_nulls && this->nulls_observed) ||
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,22 @@ Result<TypeHolder> ListValuesType(KernelContext* ctx,
return value_type;
}

Result<TypeHolder> MaxPrecisionDecimalType(KernelContext*, const std::vector<TypeHolder>& args) {
Result<TypeHolder> MaxPrecisionDecimalType(KernelContext*,
const std::vector<TypeHolder>& args) {
std::shared_ptr<DataType> out_type_;
auto type_id = args[0].type->id();
if (type_id == Type::DECIMAL128 || type_id == Type::DECIMAL256) {
auto base_type_ = checked_cast<const DecimalType*>(args[0].type);
if (type_id == Type::DECIMAL128) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, base_type_->scale()));
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision,
base_type_->scale()));
} else {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, base_type_->scale()));
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision,
base_type_->scale()));
}
} else {
return Status::TypeError("A call to MaxPrecisionDecimalType was made with a non-DecimalType");
return Status::TypeError(
"A call to MaxPrecisionDecimalType was made with a non-DecimalType");
}
return out_type_;
}
Expand Down

0 comments on commit 57e655c

Please sign in to comment.