Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test: Refine astToExecutor #5895

Merged
merged 14 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dbms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_headers_and_sources(dbms src/DataStreams)
add_headers_and_sources(dbms src/DataTypes)
add_headers_and_sources(dbms src/Databases)
add_headers_and_sources(dbms src/Debug)
add_headers_and_sources(dbms src/Debug/MockExecutor)
add_headers_and_sources(dbms src/Dictionaries)
add_headers_and_sources(dbms src/Dictionaries/Embedded)
add_headers_and_sources(dbms src/Dictionaries/Embedded/GeodataProviders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#include <AggregateFunctions/AggregateFunctionUniq.h>
#include <DataTypes/FieldToDataType.h>
#include <Debug/MockComputeServerManager.h>
#include <Debug/astToExecutor.h>
#include <Debug/MockExecutor/astToExecutor.h>
#include <Debug/MockExecutor/funcSigs.h>
#include <Flash/Coprocessor/ChunkCodec.h>
#include <Flash/Coprocessor/DAGCodec.h>
#include <Flash/Coprocessor/DAGUtils.h>
Expand All @@ -31,8 +32,16 @@

namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
extern const int NO_SUCH_COLUMN_IN_TABLE;
} // namespace ErrorCodes

using ASTPartitionByElement = ASTOrderByElement;
using MockComputeServerManager = tests::MockComputeServerManager;

void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & val_field, tipb::Expr * expr, Int32 collator_id)
{
*(expr->mutable_field_type()) = columnInfoToFieldType(ci);
Expand Down Expand Up @@ -120,96 +129,6 @@ void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & val_field, tipb
}
}

namespace
{
std::unordered_map<String, tipb::ScalarFuncSig> func_name_to_sig({
{"plusint", tipb::ScalarFuncSig::PlusInt},
{"minusint", tipb::ScalarFuncSig::MinusInt},
{"equals", tipb::ScalarFuncSig::EQInt},
{"notEquals", tipb::ScalarFuncSig::NEInt},
{"and", tipb::ScalarFuncSig::LogicalAnd},
{"or", tipb::ScalarFuncSig::LogicalOr},
{"xor", tipb::ScalarFuncSig::LogicalXor},
{"not", tipb::ScalarFuncSig::UnaryNotInt},
{"greater", tipb::ScalarFuncSig::GTInt},
{"greaterorequals", tipb::ScalarFuncSig::GEInt},
{"less", tipb::ScalarFuncSig::LTInt},
{"lessorequals", tipb::ScalarFuncSig::LEInt},
{"in", tipb::ScalarFuncSig::InInt},
{"notin", tipb::ScalarFuncSig::InInt},
{"date_format", tipb::ScalarFuncSig::DateFormatSig},
{"if", tipb::ScalarFuncSig::IfInt},
{"from_unixtime", tipb::ScalarFuncSig::FromUnixTime2Arg},
/// bit_and/bit_or/bit_xor is aggregated function in clickhouse/mysql
{"bitand", tipb::ScalarFuncSig::BitAndSig},
{"bitor", tipb::ScalarFuncSig::BitOrSig},
{"bitxor", tipb::ScalarFuncSig::BitXorSig},
{"bitnot", tipb::ScalarFuncSig::BitNegSig},
{"notequals", tipb::ScalarFuncSig::NEInt},
{"like", tipb::ScalarFuncSig::LikeSig},
{"cast_int_int", tipb::ScalarFuncSig::CastIntAsInt},
{"cast_int_real", tipb::ScalarFuncSig::CastIntAsReal},
{"cast_real_int", tipb::ScalarFuncSig::CastRealAsInt},
{"cast_real_real", tipb::ScalarFuncSig::CastRealAsReal},
{"cast_decimal_int", tipb::ScalarFuncSig::CastDecimalAsInt},
{"cast_time_int", tipb::ScalarFuncSig::CastTimeAsInt},
{"cast_string_int", tipb::ScalarFuncSig::CastStringAsInt},
{"cast_int_decimal", tipb::ScalarFuncSig::CastIntAsDecimal},
{"cast_real_decimal", tipb::ScalarFuncSig::CastRealAsDecimal},
{"cast_decimal_decimal", tipb::ScalarFuncSig::CastDecimalAsDecimal},
{"cast_time_decimal", tipb::ScalarFuncSig::CastTimeAsDecimal},
{"cast_string_decimal", tipb::ScalarFuncSig::CastStringAsDecimal},
{"cast_int_string", tipb::ScalarFuncSig::CastIntAsString},
{"cast_real_string", tipb::ScalarFuncSig::CastRealAsString},
{"cast_decimal_string", tipb::ScalarFuncSig::CastDecimalAsString},
{"cast_time_string", tipb::ScalarFuncSig::CastTimeAsString},
{"cast_string_string", tipb::ScalarFuncSig::CastStringAsString},
{"cast_int_date", tipb::ScalarFuncSig::CastIntAsTime},
{"cast_real_date", tipb::ScalarFuncSig::CastRealAsTime},
{"cast_decimal_date", tipb::ScalarFuncSig::CastDecimalAsTime},
{"cast_time_date", tipb::ScalarFuncSig::CastTimeAsTime},
{"cast_string_date", tipb::ScalarFuncSig::CastStringAsTime},
{"cast_int_datetime", tipb::ScalarFuncSig::CastIntAsTime},
{"cast_real_datetime", tipb::ScalarFuncSig::CastRealAsTime},
{"cast_decimal_datetime", tipb::ScalarFuncSig::CastDecimalAsTime},
{"cast_time_datetime", tipb::ScalarFuncSig::CastTimeAsTime},
{"cast_string_datetime", tipb::ScalarFuncSig::CastStringAsTime},
{"concat", tipb::ScalarFuncSig::Concat},
{"round_int", tipb::ScalarFuncSig::RoundInt},
{"round_uint", tipb::ScalarFuncSig::RoundInt},
{"round_dec", tipb::ScalarFuncSig::RoundDec},
{"round_real", tipb::ScalarFuncSig::RoundReal},
{"round_with_frac_int", tipb::ScalarFuncSig::RoundWithFracInt},
{"round_with_frac_uint", tipb::ScalarFuncSig::RoundWithFracInt},
{"round_with_frac_dec", tipb::ScalarFuncSig::RoundWithFracDec},
{"round_with_frac_real", tipb::ScalarFuncSig::RoundWithFracReal},
});

std::unordered_map<String, tipb::ExprType> agg_func_name_to_sig({
{"min", tipb::ExprType::Min},
{"max", tipb::ExprType::Max},
{"count", tipb::ExprType::Count},
{"sum", tipb::ExprType::Sum},
{"first_row", tipb::ExprType::First},
{"uniqRawRes", tipb::ExprType::ApproxCountDistinct},
{"group_concat", tipb::ExprType::GroupConcat},
});

std::unordered_map<String, tipb::ExprType> window_func_name_to_sig({
{"RowNumber", tipb::ExprType::RowNumber},
{"Rank", tipb::ExprType::Rank},
{"DenseRank", tipb::ExprType::DenseRank},
{"Lead", tipb::ExprType::Lead},
{"Lag", tipb::ExprType::Lag},
});

DAGColumnInfo toNullableDAGColumnInfo(const DAGColumnInfo & input)
{
DAGColumnInfo output = input;
output.second.clearNotNullFlag();
return output;
}

void literalToPB(tipb::Expr * expr, const Field & value, int32_t collator_id)
{
DataTypePtr type = applyVisitor(FieldToDataType(), value);
Expand Down Expand Up @@ -311,19 +230,6 @@ void astToPB(const DAGSchema & input, ASTPtr ast, tipb::Expr * expr, int32_t col
}
}

auto checkSchema(const DAGSchema & input, String checked_column)
{
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
auto [checked_db_name, checked_table_name, checked_column_name] = splitQualifiedName(checked_column);
auto [db_name, table_name, column_name] = splitQualifiedName(field.first);
if (checked_table_name.empty())
return column_name == checked_column_name;
else
return table_name == checked_table_name && column_name == checked_column_name;
});
return ft;
}

void functionToPB(const DAGSchema & input, ASTFunction * func, tipb::Expr * expr, int32_t collator_id, const Context & context)
{
/// aggregation function is handled in Aggregation, so just treated as a column
Expand All @@ -345,8 +251,8 @@ void functionToPB(const DAGSchema & input, ASTFunction * func, tipb::Expr * expr
// TODO: Support more functions.
// TODO: Support type inference.

const auto it_sig = func_name_to_sig.find(func_name_lowercase);
if (it_sig == func_name_to_sig.end())
const auto it_sig = tests::func_name_to_sig.find(func_name_lowercase);
if (it_sig == tests::func_name_to_sig.end())
{
throw Exception("Unsupported function: " + func_name_lowercase, ErrorCodes::LOGICAL_ERROR);
}
Expand Down Expand Up @@ -617,8 +523,8 @@ TiDB::ColumnInfo compileExpr(const DAGSchema & input, ASTPtr ast)
{
/// check function
String func_name_lowercase = Poco::toLower(func->name);
const auto it_sig = func_name_to_sig.find(func_name_lowercase);
if (it_sig == func_name_to_sig.end())
const auto it_sig = tests::func_name_to_sig.find(func_name_lowercase);
if (it_sig == tests::func_name_to_sig.end())
{
throw Exception("Unsupported function: " + func_name_lowercase, ErrorCodes::LOGICAL_ERROR);
}
Expand Down Expand Up @@ -789,42 +695,6 @@ void compileFilter(const DAGSchema & input, ASTPtr ast, std::vector<ASTPtr> & co
conditions.push_back(ast);
compileExpr(input, ast);
}
} // namespace

namespace Debug
{
String LOCAL_HOST = "127.0.0.1:3930";

void setServiceAddr(const std::string & addr)
{
LOCAL_HOST = addr;
}
} // namespace Debug

ColumnName splitQualifiedName(const String & s)
{
ColumnName ret;
Poco::StringTokenizer string_tokens(s, ".");

switch (string_tokens.count())
{
case 1:
ret.column_name = s;
break;
case 2:
ret.table_name = string_tokens[0];
ret.column_name = string_tokens[1];
break;
case 3:
ret.db_name = string_tokens[0];
ret.table_name = string_tokens[1];
ret.column_name = string_tokens[2];
break;
default:
throw Exception("Invalid identifier name " + s);
}
return ret;
}

namespace mock
{
Expand Down Expand Up @@ -1038,8 +908,8 @@ bool Aggregation::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collato
tipb::Expr * arg_expr = agg_func->add_children();
astToPB(input_schema, arg, arg_expr, collator_id, context);
}
auto agg_sig_it = agg_func_name_to_sig.find(func->name);
if (agg_sig_it == agg_func_name_to_sig.end())
auto agg_sig_it = tests::agg_func_name_to_sig.find(func->name);
if (agg_sig_it == tests::agg_func_name_to_sig.end())
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
auto agg_sig = agg_sig_it->second;
agg_func->set_tp(agg_sig);
Expand Down Expand Up @@ -1453,8 +1323,8 @@ bool Window::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id,
tipb::Expr * func = window_expr->add_children();
astToPB(input_schema, arg, func, collator_id, context);
}
auto window_sig_it = window_func_name_to_sig.find(window_func->name);
if (window_sig_it == window_func_name_to_sig.end())
auto window_sig_it = tests::window_func_name_to_sig.find(window_func->name);
if (window_sig_it == tests::window_func_name_to_sig.end())
throw Exception(fmt::format("Unsupported window function {}", window_func->name), ErrorCodes::LOGICAL_ERROR);
auto window_sig = window_sig_it->second;
window_expr->set_tp(window_sig);
Expand Down Expand Up @@ -1905,7 +1775,7 @@ ExecutorPtr compileWindow(ExecutorPtr input, size_t & executor_index, ASTPtr fun
}
// TODO: add more window functions
TiDB::ColumnInfo ci;
switch (window_func_name_to_sig[func->name])
switch (tests::window_func_name_to_sig[func->name])
{
case tipb::ExprType::RowNumber:
case tipb::ExprType::Rank:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <Debug/DAGProperties.h>
#include <Debug/DBGInvoker.h>
#include <Debug/MockExecutor/astToExecutorUtils.h>
#include <Debug/MockServerInfo.h>
#include <Debug/MockTiDB.h>
#include <Functions/FunctionFactory.h>
Expand All @@ -35,34 +36,6 @@

namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
extern const int NO_SUCH_COLUMN_IN_TABLE;
} // namespace ErrorCodes

using DAGColumnInfo = std::pair<String, ColumnInfo>;
using DAGSchema = std::vector<DAGColumnInfo>;

namespace Debug
{
extern String LOCAL_HOST;
void setServiceAddr(const std::string & addr);
} // namespace Debug

// We use qualified format like "db_name.table_name.column_name"
// to identify one column of a table.
// We can split the qualified format into the ColumnName struct.
struct ColumnName
{
String db_name;
String table_name;
String column_name;
};

ColumnName splitQualifiedName(const String & s);

struct MPPCtx
{
Timestamp start_ts;
Expand Down
74 changes: 74 additions & 0 deletions dbms/src/Debug/MockExecutor/astToExecutorUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Debug/MockExecutor/astToExecutorUtils.h>

namespace DB
{
/// Breaks a dot-separated identifier into its database / table / column parts.
///
/// Accepted shapes are "column", "table.column" and "db.table.column";
/// parts that are not present are left as default-constructed strings.
/// Any other number of dot-separated tokens raises an Exception.
ColumnName splitQualifiedName(const String & s)
{
    const Poco::StringTokenizer parts(s, ".");
    const auto part_count = parts.count();

    // Only one to three tokens form a valid identifier.
    if (part_count < 1 || part_count > 3)
        throw Exception("Invalid identifier name " + s);

    ColumnName name;
    if (part_count == 1)
    {
        // An unqualified name is stored verbatim.
        name.column_name = s;
        return name;
    }

    // The column is always the last token and the table the one before it;
    // a database prefix exists only in the three-token form.
    name.column_name = parts[part_count - 1];
    name.table_name = parts[part_count - 2];
    if (part_count == 3)
        name.db_name = parts[0];
    return name;
}


std::__wrap_iter<const std::pair<std::string, TiDB::ColumnInfo> *> checkSchema(const DAGSchema & input, String checked_column)
{
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
auto [checked_db_name, checked_table_name, checked_column_name] = splitQualifiedName(checked_column);
auto [db_name, table_name, column_name] = splitQualifiedName(field.first);
if (checked_table_name.empty())
return column_name == checked_column_name;
else
return table_name == checked_table_name && column_name == checked_column_name;
});
return ft;
}

/// Returns a copy of `input` on whose column info clearNotNullFlag() has been called.
/// The argument itself is left untouched.
DAGColumnInfo toNullableDAGColumnInfo(const DAGColumnInfo & input)
{
    DAGColumnInfo nullable_column(input);
    auto & column_info = nullable_column.second;
    column_info.clearNotNullFlag();
    return nullable_column;
}

namespace Debug
{
/// Mutable global holding a "host:port" string; starts out as the loopback address below.
/// NOTE(review): presumably the service address used by the debug/mock tooling — confirm with callers.
String LOCAL_HOST{"127.0.0.1:3930"};

/// Replaces the current value of LOCAL_HOST with `addr`.
void setServiceAddr(const std::string & addr)
{
    LOCAL_HOST = addr;
}
} // namespace Debug
} // namespace DB
Loading