-
Notifications
You must be signed in to change notification settings - Fork 461
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6b87b7f
commit 92e9224
Showing
8 changed files
with
244 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ struct Settings; | |
|
||
namespace ErrorCodes | ||
{ | ||
|
||
extern const int ILLEGAL_TYPE_OF_ARGUMENT; | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
148 changes: 148 additions & 0 deletions
148
cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxCountDistinctFunctionParser.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
|
||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <Parser/AggregateFunctionParser.h> | ||
#include <DataTypes/DataTypeNullable.h> | ||
#include <Poco/Logger.h> | ||
#include <Common/logger_useful.h> | ||
#include "DataTypes/DataTypeAggregateFunction.h" | ||
#include "DataTypes/DataTypesNumber.h" | ||
#include "Functions/FunctionHelpers.h" | ||
|
||
namespace DB | ||
{ | ||
namespace ErrorCodes | ||
{ | ||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; | ||
extern const int BAD_ARGUMENTS; | ||
} | ||
} | ||
|
||
namespace local_engine | ||
{ | ||
|
||
using namespace DB; | ||
|
||
struct ApproxCountDistinctNameStruct | ||
{ | ||
static constexpr auto spark_name = "approx_count_distinct"; | ||
static constexpr auto ch_name = "uniqHLLPP"; | ||
}; | ||
|
||
/// Spark approx_count_distinct(expr, relative_sd) = CH uniqHLLPP(relative_sd)(if(isNull(expr), null, sparkXxHash64(expr))) | ||
/// Spark approx_count_distinct(expr) = CH uniqHLLPP(0.05)(if(isNull(expr), null, sparkXxHash64(expr))) | ||
template <typename NameStruct> | ||
class AggregateFunctionParserApproxCountDistinct final : public AggregateFunctionParser | ||
{ | ||
public: | ||
static constexpr auto name = NameStruct::spark_name; | ||
|
||
AggregateFunctionParserApproxCountDistinct(ParserContextPtr parser_context_) : AggregateFunctionParser(parser_context_) { } | ||
~AggregateFunctionParserApproxCountDistinct() override = default; | ||
|
||
String getName() const override { return NameStruct::spark_name; } | ||
|
||
String getCHFunctionName(const CommonFunctionInfo &) const override { return NameStruct::ch_name; } | ||
|
||
String getCHFunctionName(DataTypes & types) const override | ||
{ | ||
/// Always invoked during second stage, the first argument is expr, the second argument is relative_sd. | ||
/// 1. Remove the second argument because types are used to create the aggregate function. | ||
/// 2. Replace the first argument type with UInt64 or Nullable(UInt64) because uniqHLLPP requres it. | ||
types.resize(1); | ||
const auto old_type = types[0]; | ||
types[0] = std::make_shared<DataTypeUInt64>(); | ||
if (old_type->isNullable()) | ||
types[0] = std::make_shared<DataTypeNullable>(types[0]); | ||
|
||
return NameStruct::ch_name; | ||
} | ||
|
||
Array parseFunctionParameters( | ||
const CommonFunctionInfo & func_info, ActionsDAG::NodeRawConstPtrs & arg_nodes, ActionsDAG & actions_dag) const override | ||
{ | ||
if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE | ||
|| func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT | ||
|| func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED) | ||
{ | ||
const auto & arguments = func_info.arguments; | ||
const size_t num_args = arguments.size(); | ||
const size_t num_nodes = arg_nodes.size(); | ||
if (num_args != num_nodes || num_args > 2 || num_args < 1 || num_nodes > 2 || num_nodes < 1) | ||
throw Exception( | ||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, | ||
"Function {} takes 1 or 2 arguments in phase {}", | ||
getName(), | ||
magic_enum::enum_name(func_info.phase)); | ||
|
||
Array params; | ||
if (num_args == 2) | ||
{ | ||
const auto & relative_sd_arg = arguments[1].value(); | ||
if (relative_sd_arg.has_literal()) | ||
{ | ||
auto [_, field] = parseLiteral(relative_sd_arg.literal()); | ||
params.push_back(std::move(field)); | ||
} | ||
else | ||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of function {} must be literal", getName()); | ||
} | ||
else | ||
{ | ||
params.push_back(0.05); | ||
} | ||
|
||
const auto & expr_arg = arg_nodes[0]; | ||
const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {expr_arg}); | ||
const auto * hash_node = toFunctionNode(actions_dag, "sparkXxHash64", {expr_arg}); | ||
const auto * null_node | ||
= addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt64>()), {}); | ||
const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_node, hash_node}); | ||
/// Replace the first argument expr with if(isNull(expr), null, sparkXxHash64(expr)) | ||
arg_nodes[0] = if_node; | ||
arg_nodes.resize(1); | ||
|
||
return params; | ||
} | ||
else | ||
{ | ||
if (arg_nodes.size() != 1) | ||
throw Exception( | ||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, | ||
"Function {} takes 1 argument in phase {}", | ||
getName(), | ||
magic_enum::enum_name(func_info.phase)); | ||
|
||
const auto & result_type = arg_nodes[0]->result_type; | ||
const auto * aggregate_function_type = checkAndGetDataType<DataTypeAggregateFunction>(result_type.get()); | ||
if (!aggregate_function_type) | ||
throw Exception( | ||
ErrorCodes::BAD_ARGUMENTS, | ||
"The first argument type of function {} in phase {} must be AggregateFunction, but is {}", | ||
getName(), | ||
magic_enum::enum_name(func_info.phase), | ||
result_type->getName()); | ||
|
||
return aggregate_function_type->getParameters(); | ||
} | ||
} | ||
|
||
Array getDefaultFunctionParameters() const override { return {0.05}; } | ||
}; | ||
|
||
static const AggregateFunctionParserRegister<AggregateFunctionParserApproxCountDistinct<ApproxCountDistinctNameStruct>> registerer_approx_count_distinct; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.