Skip to content

Commit

Permalink
[fix](decimal256) support decimal256 for many functions (apache#42136)
Browse files Browse the repository at this point in the history
Issue Number: close #xxx

Support decimal256 for the following functions:
```
multi_distinct_sum
multi_distinct_count
array_sum
array_avg
array_product
array_cum_sum
```
  • Loading branch information
jacktengg committed Oct 24, 2024
1 parent 24f2c47 commit 9e2019a
Show file tree
Hide file tree
Showing 65 changed files with 953 additions and 215 deletions.
2 changes: 1 addition & 1 deletion be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class RuntimeState {
_query_options.check_overflow_for_decimal;
}

bool enable_decima256() const {
bool enable_decimal256() const {
return _query_options.__isset.enable_decimal256 && _query_options.enable_decimal256;
}

Expand Down
4 changes: 4 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class Arena;
class IColumn;
class IDataType;

struct AggregateFunctionAttr {
bool enable_decimal256 {false};
};

template <bool nullable, typename ColVecType>
class AggregateFunctionBitmapCount;
template <typename Op>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_approx_count_distinct(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType which(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE, COLUMN_TYPE) \
Expand Down
15 changes: 12 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,17 @@ template <typename T>
using AggregateFuncAvgDecimal256 = typename AvgDecimal256<T>::Function;

void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("avg", creator_with_type::creator<AggregateFuncAvg>);
factory.register_function_both("avg_decimal256",
creator_with_type::creator<AggregateFuncAvgDecimal256>);
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (attr.enable_decimal256) {
return creator_with_type::creator<AggregateFuncAvgDecimal256>(name, types,
result_is_nullable, attr);
} else {
return creator_with_type::creator<AggregateFuncAvg>(name, types, result_is_nullable,
attr);
}
};
factory.register_function_both("avg", creator);
}
} // namespace doris::vectorized
9 changes: 5 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) {
return nullptr;
}

AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_bitmap_union_count(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return std::make_shared<AggregateFunctionBitmapCount<true, ColumnBitmap>>(argument_types);
Expand All @@ -53,7 +53,8 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const std::str

AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return AggregateFunctionPtr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types)

AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return AggregateFunctionPtr(create_with_int_data_type<true>(argument_types));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n

AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
if (name == "array_agg") {
return create_aggregate_function_collect_impl<std::false_type, std::true_type>(
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ struct CorrMoment {

AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_binary(name, argument_types);
return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0], argument_types[1],
argument_types, result_is_nullable);
Expand Down
9 changes: 5 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_at_most<1>(name, argument_types);

return std::make_shared<AggregateFunctionCount>(argument_types);
}

AggregateFunctionPtr create_aggregate_function_count_not_null_unary(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_count_not_null_unary(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_at_most<1>(name, argument_types);

return std::make_shared<AggregateFunctionCountNotNullUnary>(argument_types);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() < 1) {
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}",
argument_types.size(), name);
Expand Down
6 changes: 4 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ AggregateFunctionPtr create_function_single_value(const String& name,

AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionSamp, CovarSampName, SampData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionPop, CovarName, PopData>(
name, argument_types, result_is_nullable, NOTNULLABLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_";

void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
// 1. we should get not nullable types;
DataTypes nested_types(types.size());
std::transform(types.begin(), types.end(), nested_types.begin(),
Expand All @@ -92,7 +93,7 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact
auto transform_arguments = function_combinator->transform_arguments(nested_types);
auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size());
auto nested_function = factory.get(nested_function_name, transform_arguments, false,
BeExecVersionManager::get_newest_version());
BeExecVersionManager::get_newest_version(), attr);
return function_combinator->transform_aggregate_function(nested_function, types,
result_is_nullable);
};
Expand Down
7 changes: 4 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
namespace doris::vectorized {

void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const bool result_is_nullable) -> AggregateFunctionPtr {
AggregateFunctionCreator creator =
[&](const std::string& name, const DataTypes& types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) -> AggregateFunctionPtr {
const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX;
DataTypes transform_arguments;
for (const auto& t : types) {
Expand All @@ -46,7 +47,7 @@ void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFacto
auto nested_function_name = name.substr(0, name.size() - suffix.size());
auto nested_function =
factory.get(nested_function_name, transform_arguments, result_is_nullable,
BeExecVersionManager::get_newest_version(), false);
BeExecVersionManager::get_newest_version(), attr);
if (!nested_function) {
throw Exception(
ErrorCode::INTERNAL_ERROR,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl
}

AggregateFunctionPtr create_aggregate_function_group_array_intersect(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_unary(name, argument_types);
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const std::string AggregateFunctionGroupConcatImplStr::separator = ",";

AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
return creator_without_type::create<
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ AggregateFunctionPtr create_agg_function_histogram(const DataTypes& argument_typ

AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes&

AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() != 1) {
LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument";
return nullptr;
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ AggregateFunctionPtr create_agg_function_map_agg(const DataTypes& argument_types

AggregateFunctionPtr create_aggregate_function_map_agg(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
WhichDataType type(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace doris::vectorized {
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_unary(name, argument_types);

AggregateFunctionPtr res(creator_with_numeric_type::create<AggregateFunctionsSingleValue, Data,
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_min_max.h
Original file line number Diff line number Diff line change
Expand Up @@ -714,5 +714,6 @@ class AggregateFunctionsSingleValue final
template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable);
const bool result_is_nullable,
const AggregateFunctionAttr& attr = {});
} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data>
AggregateFunctionPtr create_aggregate_function_min_max_by(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() != 2) {
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ template <template <typename> class Impl>
AggregateFunctionPtr create_aggregate_function_orthogonal(const std::string& name,
const DataTypes& argument_types,

const bool result_is_nullable) {
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.empty()) {
LOG(WARNING) << "Incorrect number of arguments for aggregate function " << name;
return nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
AggregateFunctionPtr create_aggregate_function_percentile_approx(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
Expand All @@ -43,7 +43,8 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri
}

AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
Expand Down
Loading

0 comments on commit 9e2019a

Please sign in to comment.