diff --git a/hybridse/examples/toydb/src/tablet/tablet_catalog.h b/hybridse/examples/toydb/src/tablet/tablet_catalog.h index 3ea97d325b6..08be85f2568 100644 --- a/hybridse/examples/toydb/src/tablet/tablet_catalog.h +++ b/hybridse/examples/toydb/src/tablet/tablet_catalog.h @@ -221,21 +221,18 @@ class TabletCatalog : public vm::Catalog { bool AddTable(std::shared_ptr table); - std::shared_ptr GetDatabase(const std::string& db); + std::shared_ptr GetDatabase(const std::string& db) override; + + std::shared_ptr GetTable(const std::string& db, const std::string& table_name) override; - std::shared_ptr GetTable(const std::string& db, - const std::string& table_name); bool IndexSupport() override; - std::vector GetAggrTables( - const std::string& base_db, - const std::string& base_table, - const std::string& aggr_func, - const std::string& aggr_col, - const std::string& partition_cols, - const std::string& order_col) override { - vm::AggrTableInfo info = {"aggr_" + base_table, "aggr_db", base_db, base_table, - aggr_func, aggr_col, partition_cols, order_col, "1000"}; + std::vector GetAggrTables(const std::string& base_db, const std::string& base_table, + const std::string& aggr_func, const std::string& aggr_col, + const std::string& partition_cols, const std::string& order_col, + const std::string& filter_col) override { + vm::AggrTableInfo info = {"aggr_" + base_table, "aggr_db", base_db, base_table, aggr_func, aggr_col, + partition_cols, order_col, "1000", filter_col}; return {info}; } diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 5a0d7459ee6..66e75f2e220 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -24,6 +24,7 @@ #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "boost/algorithm/string.hpp" @@ -1076,6 +1077,33 @@ class ConstNode : public ExprNode { } } + // include 'udf/literal_traits.h' for Nullable lead 
to recursive include + // so `optional` is used for nullable info + template + absl::StatusOr> GetAs() const { + if (IsNull()) { + return std::nullopt; + } + + if constexpr (std::is_same_v) { + return GetBool(); + } else if constexpr(std::is_same_v) { + return GetAsInt16(); + } else if constexpr (std::is_same_v) { + return GetAsInt32(); + } else if constexpr (std::is_same_v) { + return GetAsInt64(); + } else if constexpr (std::is_same_v) { + return GetAsFloat(); + } else if constexpr (std::is_same_v) { + return GetAsDouble(); + } else if constexpr (std::is_same_v) { + return GetAsString(); + } else { + return absl::InvalidArgumentError("can't cast as T"); + } + } + Status InferAttr(ExprAnalysisContext *ctx) override; static ConstNode *CastFrom(ExprNode *node); @@ -1620,7 +1648,7 @@ class ColumnRefNode : public ExprNode { void SetRelationName(const std::string &relation_name) { relation_name_ = relation_name; } - std::string GetColumnName() const { return column_name_; } + const std::string &GetColumnName() const { return column_name_; } void SetColumnName(const std::string &column_name) { column_name_ = column_name; } diff --git a/hybridse/include/vm/catalog.h b/hybridse/include/vm/catalog.h index 7980fdbd5f0..30e68316606 100644 --- a/hybridse/include/vm/catalog.h +++ b/hybridse/include/vm/catalog.h @@ -471,6 +471,7 @@ struct AggrTableInfo { std::string partition_cols; std::string order_by_col; std::string bucket_size; + std::string filter_col; bool operator==(const AggrTableInfo& rhs) const { return aggr_table == rhs.aggr_table && @@ -481,7 +482,8 @@ struct AggrTableInfo { aggr_col == rhs.aggr_col && partition_cols == rhs.partition_cols && order_by_col == rhs.order_by_col && - bucket_size == rhs.bucket_size; + bucket_size == rhs.bucket_size && + filter_col == rhs.filter_col; } }; @@ -514,13 +516,10 @@ class Catalog { return nullptr; } - virtual std::vector GetAggrTables( - const std::string& base_db, - const std::string& base_table, - const std::string& aggr_func, 
- const std::string& aggr_col, - const std::string& partition_cols, - const std::string& order_col) { + virtual std::vector GetAggrTables(const std::string& base_db, const std::string& base_table, + const std::string& aggr_func, const std::string& aggr_col, + const std::string& partition_cols, const std::string& order_col, + const std::string& filter_col) { return std::vector(); } }; diff --git a/hybridse/include/vm/physical_op.h b/hybridse/include/vm/physical_op.h index e3d5ad10bb7..98b731d04e8 100644 --- a/hybridse/include/vm/physical_op.h +++ b/hybridse/include/vm/physical_op.h @@ -785,7 +785,7 @@ class PhysicalReduceAggregationNode : public PhysicalProjectNode { } virtual ~PhysicalReduceAggregationNode() {} base::Status InitSchema(PhysicalPlanContext *) override; - virtual void Print(std::ostream &output, const std::string &tab) const; + void Print(std::ostream &output, const std::string &tab) const override; ConditionFilter having_condition_; const PhysicalAggregationNode* orig_aggr_ = nullptr; }; @@ -1500,26 +1500,25 @@ class PhysicalRequestAggUnionNode : public PhysicalOpNode { PhysicalRequestAggUnionNode(PhysicalOpNode *request, PhysicalOpNode *raw, PhysicalOpNode *aggr, const RequestWindowOp &window, const RequestWindowOp &aggr_window, bool instance_not_in_window, bool exclude_current_time, bool output_request_row, - const node::FnDefNode *func, const node::ExprNode* agg_col) + const node::CallExprNode *project) : PhysicalOpNode(kPhysicalOpRequestAggUnion, true), window_(window), agg_window_(aggr_window), - func_(func), - agg_col_(agg_col), + project_(project), instance_not_in_window_(instance_not_in_window), exclude_current_time_(exclude_current_time), output_request_row_(output_request_row) { output_type_ = kSchemaTypeTable; - fn_infos_.push_back(&window_.partition_.fn_info()); - fn_infos_.push_back(&window_.sort_.fn_info()); - fn_infos_.push_back(&window_.range_.fn_info()); - fn_infos_.push_back(&window_.index_key_.fn_info()); + 
AddFnInfo(&window_.partition_.fn_info()); + AddFnInfo(&window_.sort_.fn_info()); + AddFnInfo(&window_.range_.fn_info()); + AddFnInfo(&window_.index_key_.fn_info()); - fn_infos_.push_back(&agg_window_.partition_.fn_info()); - fn_infos_.push_back(&agg_window_.sort_.fn_info()); - fn_infos_.push_back(&agg_window_.range_.fn_info()); - fn_infos_.push_back(&agg_window_.index_key_.fn_info()); + AddFnInfo(&agg_window_.partition_.fn_info()); + AddFnInfo(&agg_window_.sort_.fn_info()); + AddFnInfo(&agg_window_.range_.fn_info()); + AddFnInfo(&agg_window_.index_key_.fn_info()); AddProducers(request, raw, aggr); } @@ -1547,11 +1546,18 @@ class PhysicalRequestAggUnionNode : public PhysicalOpNode { RequestWindowOp window_; RequestWindowOp agg_window_; - const node::FnDefNode* func_ = nullptr; - const node::ExprNode* agg_col_; + + // for long window, each node has only one projection node + const node::CallExprNode* project_; const SchemasContext* parent_schema_context_ = nullptr; private: + void AddProducers(PhysicalOpNode *request, PhysicalOpNode *raw, PhysicalOpNode *aggr) { + AddProducer(request); + AddProducer(raw); + AddProducer(aggr); + } + const bool instance_not_in_window_; const bool exclude_current_time_; @@ -1563,12 +1569,6 @@ class PhysicalRequestAggUnionNode : public PhysicalOpNode { // `EXCLUDE CURRENT_ROW` bool output_request_row_; - void AddProducers(PhysicalOpNode *request, PhysicalOpNode *raw, PhysicalOpNode *aggr) { - AddProducer(request); - AddProducer(raw); - AddProducer(aggr); - } - Schema agg_schema_; }; diff --git a/hybridse/include/vm/simple_catalog.h b/hybridse/include/vm/simple_catalog.h index c8a78bca52a..1e1cd78a2f6 100644 --- a/hybridse/include/vm/simple_catalog.h +++ b/hybridse/include/vm/simple_catalog.h @@ -98,13 +98,10 @@ class SimpleCatalog : public Catalog { bool InsertRows(const std::string &db, const std::string &table, const std::vector &row); - std::vector GetAggrTables( - const std::string& base_db, - const std::string& base_table, - const 
std::string& aggr_func, - const std::string& aggr_col, - const std::string& partition_cols, - const std::string& order_col) override; + std::vector GetAggrTables(const std::string &base_db, const std::string &base_table, + const std::string &aggr_func, const std::string &aggr_col, + const std::string &partition_cols, const std::string &order_col, + const std::string &filter_col) override; private: bool enable_index_; diff --git a/hybridse/src/passes/physical/long_window_optimized.cc b/hybridse/src/passes/physical/long_window_optimized.cc index 0c8ae8b99b8..48c2a5d1ef7 100644 --- a/hybridse/src/passes/physical/long_window_optimized.cc +++ b/hybridse/src/passes/physical/long_window_optimized.cc @@ -15,17 +15,23 @@ */ #include "passes/physical/long_window_optimized.h" -#include - #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "vm/engine.h" #include "vm/physical_op.h" namespace hybridse { namespace passes { +static const absl::flat_hash_set WHERE_FUNS = { + "count_where", "sum_where", "avg_where", "min_where", "max_where", +}; + LongWindowOptimized::LongWindowOptimized(PhysicalPlanContext* plan_ctx) : TransformUpPysicalPass(plan_ctx) { std::vector windows; const auto* options = plan_ctx_->GetOptions(); @@ -104,16 +110,17 @@ bool LongWindowOptimized::OptimizeWithPreAggr(vm::PhysicalAggregationNode* in, i auto aggr_op = dynamic_cast(projects.GetExpr(idx)); auto window = aggr_op->GetOver(); - auto expr_type = aggr_op->GetChild(0)->GetExprType(); - if (aggr_op->GetChildNum() != 1 || (expr_type != node::kExprColumnRef && expr_type != node::kExprAll)) { - LOG(ERROR) << "Not support aggregation over multiple cols: " << ConcatExprList(aggr_op->children_); + auto s = CheckCallExpr(aggr_op); + if (!s.ok()) { + LOG(ERROR) << s.status(); return false; } const std::string& db_name = orig_data_provider->GetDb(); const std::string& table_name = 
orig_data_provider->GetName(); std::string func_name = aggr_op->GetFnDef()->GetName(); - std::string aggr_col = ConcatExprList(aggr_op->children_); + std::string aggr_col = ConcatExprList({aggr_op->children_.front()}); + std::string filter_col = std::string(s->filter_col_name); std::string partition_col; if (window->GetPartitions()) { partition_col = ConcatExprList(window->GetPartitions()->children_); @@ -145,7 +152,8 @@ bool LongWindowOptimized::OptimizeWithPreAggr(vm::PhysicalAggregationNode* in, i } } - auto table_infos = catalog_->GetAggrTables(db_name, table_name, func_name, aggr_col, partition_col, order_col); + auto table_infos = + catalog_->GetAggrTables(db_name, table_name, func_name, aggr_col, partition_col, order_col, filter_col); if (table_infos.empty()) { LOG(WARNING) << absl::StrCat("No Pre-aggregation tables exists for ", db_name, ".", table_name, ": ", func_name, "(", aggr_col, ")", " partition by ", partition_col, " order by ", order_col); @@ -210,8 +218,7 @@ bool LongWindowOptimized::OptimizeWithPreAggr(vm::PhysicalAggregationNode* in, i status = plan_ctx_->CreateOp( &request_aggr_union, request, raw, aggr, req_union_op->window(), aggr_window, req_union_op->instance_not_in_window(), req_union_op->exclude_current_time(), - req_union_op->output_request_row(), aggr_op->GetFnDef(), - aggr_op->GetChild(0)); + req_union_op->output_request_row(), aggr_op); if (req_union_op->exclude_current_row_) { request_aggr_union->set_out_request_row(false); } @@ -240,9 +247,9 @@ bool LongWindowOptimized::OptimizeWithPreAggr(vm::PhysicalAggregationNode* in, i LOG(ERROR) << "Fail to create PhysicalReduceAggregationNode: " << status; return false; } - LOG(INFO) << "[LongWindowOptimized] Before transform sql:\n" << (*output)->GetTreeString(); + DLOG(INFO) << "[LongWindowOptimized] Before transform sql:\n" << (*output)->GetTreeString(); *output = reduce_aggr; - LOG(INFO) << "[LongWindowOptimized] After transform sql:\n" << (*output)->GetTreeString(); + DLOG(INFO) << 
"[LongWindowOptimized] After transform sql:\n" << (*output)->GetTreeString(); return true; } @@ -269,5 +276,117 @@ std::string LongWindowOptimized::ConcatExprList(std::vector exp return str; } + +// type check of count_where condition node +// left -> column ref +// right -> constant +absl::StatusOr CheckCountWhereCond(const node::ExprNode* lhs, const node::ExprNode* rhs) { + if (lhs->GetExprType() != node::ExprType::kExprColumnRef) { + return absl::UnimplementedError(absl::StrCat("expect left as column reference but get ", lhs->GetExprString())); + } + if (rhs->GetExprType() != node::ExprType::kExprPrimary) { + return absl::UnimplementedError(absl::StrCat("expect right as constant but get ", rhs->GetExprString())); + } + + return dynamic_cast(lhs)->GetColumnName(); +} + +// left -> * or column name +// right -> BinaryExpr of +// lhs column name and rhs constant, or versa +// op -> (eq, ne, gt, lt, ge, le) +absl::StatusOr CheckCountWhereArgs(const node::ExprNode* right) { + if (right->GetExprType() != node::ExprType::kExprBinary) { + return absl::UnimplementedError(absl::StrCat("[Long Window] ExprType ", + node::ExprTypeName(right->GetExprType()), + " not implemented as count_where condition")); + } + auto* bin_expr = dynamic_cast(right); + if (bin_expr == nullptr) { + return absl::UnknownError("[Long Window] right can't cast to binary expr"); + } + + auto s1 = CheckCountWhereCond(right->GetChild(0), right->GetChild(1)); + auto s2 = CheckCountWhereCond(right->GetChild(1), right->GetChild(0)); + if (!s1.ok() && !s2.ok()) { + return absl::UnimplementedError( + absl::StrCat("[Long Window] cond as ", right->GetExprString(), " not support: ", s1.status().message())); + } + + switch (bin_expr->GetOp()) { + case node::FnOperator::kFnOpLe: + case node::FnOperator::kFnOpLt: + case node::FnOperator::kFnOpGt: + case node::FnOperator::kFnOpGe: + case node::FnOperator::kFnOpNeq: + case node::FnOperator::kFnOpEq: + break; + default: + return absl::UnimplementedError( + 
absl::StrCat("[Long Window] filter cond operator ", node::ExprOpTypeName(bin_expr->GetOp()))); + } + + if (s1.ok()) { + return s1.value(); + } + return s2.value(); +} + +// Supported: +// - count(col) or count(*) +// - sum(col) +// - min(col) +// - max(col) +// - avg(col) +// - count_where(col, simple_expr) +// - count_where(*, simple_expr) +// +// simple_expr can be +// - BinaryExpr +// - operand nodes of the expr can only be column ref and const node +// - with operator: +// - eq +// - neq +// - lt +// - gt +// - le +// - ge +absl::StatusOr LongWindowOptimized::CheckCallExpr(const node::CallExprNode* call) { + if (call->GetChildNum() != 1 && call->GetChildNum() != 2) { + return absl::UnimplementedError( + absl::StrCat("expect call function with argument number 1 or 2, but got ", call->GetExprString())); + } + + // count/sum/min/max/avg + auto expr_type = call->GetChild(0)->GetExprType(); + + absl::string_view key_col; + absl::string_view filter_col; + if (expr_type == node::kExprColumnRef) { + auto* col_ref = dynamic_cast(call->GetChild(0)); + key_col = col_ref->GetColumnName(); + } else if (expr_type == node::kExprAll) { + key_col = call->GetChild(0)->GetExprString(); + } else { + return absl::UnimplementedError( + absl::StrCat("[Long Window] first arg to op is not column or * :", call->GetExprString())); + } + + if (call->GetChildNum() == 2) { + if (absl::c_none_of(WHERE_FUNS, [&call](absl::string_view e) { return call->GetFnDef()->GetName() == e; })) { + return absl::UnimplementedError(absl::StrCat(call->GetFnDef()->GetName(), " not implemented")); + } + + // count_where + auto s = CheckCountWhereArgs(call->GetChild(1)); + if (!s.ok()) { + return s.status(); + } + filter_col = s.value(); + } + + return AggInfo{key_col, filter_col}; +} + } // namespace passes } // namespace hybridse diff --git a/hybridse/src/passes/physical/long_window_optimized.h b/hybridse/src/passes/physical/long_window_optimized.h index fa0cc57b3a9..58b54050cf5 100644 --- 
a/hybridse/src/passes/physical/long_window_optimized.h +++ b/hybridse/src/passes/physical/long_window_optimized.h @@ -19,6 +19,8 @@ #include #include #include + +#include "absl/status/statusor.h" #include "passes/physical/transform_up_physical_pass.h" namespace hybridse { @@ -29,12 +31,25 @@ class LongWindowOptimized : public TransformUpPysicalPass { explicit LongWindowOptimized(PhysicalPlanContext* plan_ctx); ~LongWindowOptimized() {} + public: + // e.g. count_where(col1, col2 < 4) + // -> key_col_name = col1, filter_col_name = col2 + struct AggInfo { + absl::string_view key_col_name; + absl::string_view filter_col_name; + }; + private: bool Transform(PhysicalOpNode* in, PhysicalOpNode** output) override; bool VerifySingleAggregation(vm::PhysicalProjectNode* op); bool OptimizeWithPreAggr(vm::PhysicalAggregationNode* in, int idx, PhysicalOpNode** output); + static std::string ConcatExprList(std::vector exprs, const std::string& delimiter = ","); + // Check supported ExprNode, return a non-ok status if the call expr type is not implemented; + // otherwise, return ok status with the agg info + static absl::StatusOr CheckCallExpr(const node::CallExprNode* call); + std::set long_windows_; }; } // namespace passes diff --git a/hybridse/src/testing/test_base.cc b/hybridse/src/testing/test_base.cc index f9e9cbdb4eb..7806a5a8165 100644 --- a/hybridse/src/testing/test_base.cc +++ b/hybridse/src/testing/test_base.cc @@ -55,6 +55,11 @@ void BuildAggTableDef(::hybridse::type::TableDef& table, const std::string& aggr column->set_type(::hybridse::type::kInt64); column->set_name("binlog_offset"); } + { + ::hybridse::type::ColumnDef* column = table.add_columns(); + column->set_type(::hybridse::type::kVarchar); + column->set_name("filter_key"); + } } void BuildTableDef(::hybridse::type::TableDef& table) { // NOLINT diff --git a/hybridse/src/vm/aggregator.h b/hybridse/src/vm/aggregator.h index e43c35daebe..cfe60478ffb 100644 --- a/hybridse/src/vm/aggregator.h +++ b/hybridse/src/vm/aggregator.h 
@@ -41,6 +41,8 @@ class BaseAggregator { virtual ~BaseAggregator() {} + // update aggregator states by encoded string + // usually used to update states from a pre-agg table (encoded multi-rows) virtual void Update(const std::string& val) = 0; // output final row @@ -236,7 +238,7 @@ class CountAggregator : public Aggregator { : Aggregator(type, output_schema, 0) {} // val is assumed to be not null - void UpdateValue(const int64_t& val = 1) override { + void UpdateValue(const int64_t& val) override { this->val_ += val; this->counter_++; DLOG(INFO) << "Update " << Type_Name(this->type_) << " val " << val << ", count = " << this->val_; diff --git a/hybridse/src/vm/internal/eval.cc b/hybridse/src/vm/internal/eval.cc new file mode 100644 index 00000000000..844239201f4 --- /dev/null +++ b/hybridse/src/vm/internal/eval.cc @@ -0,0 +1,251 @@ +// Copyright 2022 4Paradigm Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "vm/internal/eval.h" + +#include + +#include "codegen/ir_base_builder.h" +#include "node/node_manager.h" + +namespace hybridse { +namespace vm { +namespace internal { + +absl::StatusOr> EvalCond(const RowParser* parser, const codec::Row& row, + const node::ExprNode* cond) { + const auto* bin_expr = dynamic_cast(cond); + if (bin_expr == nullptr) { + return absl::InvalidArgumentError("can't evaluate expr other than binary expr"); + } + + auto tp = ExtractCompareType(parser, bin_expr); + if (!tp.ok()) { + return tp.status(); + } + + const auto* left = bin_expr->GetChild(0); + const auto* right = bin_expr->GetChild(1); + + switch (tp.value()) { + case type::kBool: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + case type::kInt16: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + case type::kInt32: + case type::kDate: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + case type::kTimestamp: + case type::kInt64: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + case type::kFloat: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + case type::kDouble: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + case type::kVarchar: { + return EvalBinaryExpr(parser, row, bin_expr->GetOp(), left, right); + } + default: + break; + } + + return absl::UnimplementedError(cond->GetExprString()); +} + +absl::StatusOr> EvalCondWithAggRow(const RowParser* parser, const codec::Row& row, + const node::ExprNode* cond, absl::string_view filter_col_name) { + const auto* bin_expr = dynamic_cast(cond); + if (bin_expr == nullptr) { + return absl::InvalidArgumentError("can't evaluate expr other than binary expr"); + } + + std::string filter = std::string(filter_col_name); + + // if value of filter_col_name is NULL + if (parser->IsNull(row, filter)) { + return std::nullopt; + } + + std::string filter_val; + 
parser->GetString(row, filter, &filter_val); + + const auto* left = bin_expr->GetChild(0); + const auto* right = bin_expr->GetChild(1); + node::DataType op_type; + + if (left->GetExprType() == node::kExprColumnRef) { + auto* const_node = dynamic_cast(right); + if (const_node == nullptr) { + return absl::InvalidArgumentError("expect right node as const node for evaluation"); + } + op_type = const_node->GetDataType(); + if (const_node->IsNull()) { + return std::nullopt; + } + + switch (op_type) { + case node::DataType::kBool: { + bool v; + if (!absl::SimpleAtob(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to bool")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), v, + const_node->GetAs().value_or(std::nullopt)); + } + case node::DataType::kInt16: { + int32_t v; + if (!absl::SimpleAtoi(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to int32_t")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), static_cast(v), + const_node->GetAs().value_or(std::nullopt)); + } + case node::DataType::kInt32: + case node::DataType::kDate: { + int32_t v; + if (!absl::SimpleAtoi(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to int32_t")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), v, + const_node->GetAs().value_or(std::nullopt)); + } + case node::DataType::kTimestamp: + case node::DataType::kInt64: { + int64_t v; + if (!absl::SimpleAtoi(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to int64_t")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), v, + const_node->GetAs().value_or(std::nullopt)); + } + case node::DataType::kFloat: { + float v; + if (!absl::SimpleAtof(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to flat")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), v, + 
const_node->GetAs().value_or(std::nullopt)); + } + case node::DataType::kDouble: { + double v; + if (!absl::SimpleAtod(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to double")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), v, + const_node->GetAs().value_or(std::nullopt)); + } + case node::DataType::kVarchar: { + return EvalSimpleBinaryExpr(bin_expr->GetOp(), filter_val, + const_node->GetAs().value_or(std::nullopt)); + } + default: + break; + } + } else if (right->GetExprType() == node::kExprColumnRef) { + auto* const_node = dynamic_cast(left); + if (const_node == nullptr) { + return absl::InvalidArgumentError("expect left node as const node for evaluation"); + } + op_type = const_node->GetDataType(); + + if (const_node->IsNull()) { + return std::nullopt; + } + + switch (op_type) { + case node::DataType::kBool: { + bool v; + if (!absl::SimpleAtob(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to bool")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), const_node->GetAs().value_or(std::nullopt), + v); + } + case node::DataType::kInt16: { + int32_t v; + if (!absl::SimpleAtoi(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to int32_t")); + } + return EvalSimpleBinaryExpr( + bin_expr->GetOp(), const_node->GetAs().value_or(std::nullopt), static_cast(v)); + } + case node::DataType::kInt32: + case node::DataType::kDate: { + int32_t v; + if (!absl::SimpleAtoi(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to int32_t")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), + const_node->GetAs().value_or(std::nullopt), v); + } + case node::DataType::kTimestamp: + case node::DataType::kInt64: { + int64_t v; + if (!absl::SimpleAtoi(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to int64_t")); + } + return 
EvalSimpleBinaryExpr(bin_expr->GetOp(), + const_node->GetAs().value_or(std::nullopt), v); + } + case node::DataType::kFloat: { + float v; + if (!absl::SimpleAtof(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to flat")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), const_node->GetAs().value_or(std::nullopt), + v); + } + case node::DataType::kDouble: { + double v; + if (!absl::SimpleAtod(filter_val, &v)) { + return absl::InvalidArgumentError(absl::StrCat("can't cast ", filter_val, " to double")); + } + return EvalSimpleBinaryExpr(bin_expr->GetOp(), + const_node->GetAs().value_or(std::nullopt), v); + } + case node::DataType::kVarchar: { + return EvalSimpleBinaryExpr( + bin_expr->GetOp(), const_node->GetAs().value_or(std::nullopt), filter_val); + } + default: + break; + } + } + + return absl::InvalidArgumentError(absl::StrCat("unsupport binary op: ", cond->GetExprString())); +} + +absl::StatusOr ExtractCompareType(const RowParser* parser, const node::BinaryExpr* node) { + if (node->GetChild(0)->GetExprType() == node::kExprColumnRef && + node->GetChild(1)->GetExprType() == node::kExprPrimary) { + return parser->GetType(*dynamic_cast(node->GetChild(0))); + } + if (node->GetChild(1)->GetExprType() == node::kExprColumnRef && + node->GetChild(0)->GetExprType() == node::kExprPrimary) { + return parser->GetType(*dynamic_cast(node->GetChild(1))); + } + + return absl::UnimplementedError(absl::StrCat("Evaluating type for binary expr '", node->GetExprString())); +} + + +} // namespace internal +} // namespace vm +} // namespace hybridse diff --git a/hybridse/src/vm/internal/eval.h b/hybridse/src/vm/internal/eval.h new file mode 100644 index 00000000000..7126729c8cb --- /dev/null +++ b/hybridse/src/vm/internal/eval.h @@ -0,0 +1,188 @@ +// Copyright 2022 4Paradigm Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: eval.h +// ----------------------------------------------------------------------------- +// +// Defines some runner evaluation related helper functions. +// Used by 'vm/runner.{h, cc}' where codegen evaluation is skipped, +// likely in long window runner nodes +// +// ----------------------------------------------------------------------------- + +#ifndef HYBRIDSE_SRC_VM_INTERNAL_EVAL_H_ +#define HYBRIDSE_SRC_VM_INTERNAL_EVAL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "codec/row.h" +#include "node/expr_node.h" +#include "node/node_enum.h" +#include "vm/schemas_context.h" + +namespace hybridse { +namespace vm { +namespace internal { + +// extract value from expr node +// limited implementation since it only expects the node to be one of +// * ColumnRefNode +// * ConstNode +template +absl::StatusOr> ExtractValue(const RowParser* parser, const codec::Row& row, + const node::ExprNode* node) { + if (node->GetExprType() == node::ExprType::kExprPrimary) { + const auto* const_node = dynamic_cast(node); + return const_node->GetAs(); + } + + if (node->GetExprType() == node::ExprType::kExprColumnRef) { + const auto* column_ref = dynamic_cast(node); + if (parser->IsNull(row, *column_ref)) { + return std::nullopt; + } + + if constexpr (std::is_same_v) { + std::string data; + if (0 == parser->GetString(row, *column_ref, &data)) { + return data; + } + } else if constexpr (std::is_same_v) { + bool v = false; + if (0 
== parser->GetValue(row, *column_ref, type::kBool, &v)) { + return v; + } + } else if constexpr (std::is_same_v) { + int16_t v = 0; + if (0 == parser->GetValue(row, *column_ref, type::kInt16, &v)) { + return v; + } + } else if constexpr (std::is_same_v) { + int32_t v = 0; + if (0 == parser->GetValue(row, *column_ref, type::kInt32, &v)) { + return v; + } + } else if constexpr (std::is_same_v) { + int64_t v = 0; + if (0 == parser->GetValue(row, *column_ref, type::kInt64, &v)) { + return v; + } + } else if constexpr (std::is_same_v) { + float v = 0.0; + if (0 == parser->GetValue(row, *column_ref, type::kFloat, &v)) { + return v; + } + } else if constexpr (std::is_same_v) { + double v = 0.0; + if (0 == parser->GetValue(row, *column_ref, type::kDouble, &v)) { + return v; + } + } + + return absl::UnimplementedError("not able to get value from a type different from schema"); + } + + return absl::UnimplementedError( + absl::StrCat("invalid node: ", node::ExprTypeName(node->GetExprType()), " -> ", node->GetExprString())); +} + +template +std::ostream& operator<<(std::ostream& os, const std::optional& val) { + if constexpr (std::is_same_v) { + return os << (val.has_value() ? absl::StrCat("\"", val.value(), "\"") : "NULL"); + } else { + return os << (val.has_value() ? 
std::to_string(val.value()) : "NULL"); + } +} + +template +std::optional EvalSimpleBinaryExpr(node::FnOperator op, const std::optional& lhs, + const std::optional& rhs) { + DLOG(INFO) << "[EvalSimpleBinaryExpr] " << lhs << " " << node::ExprOpTypeName(op) << " " << rhs; + + if (!lhs.has_value() || !rhs.has_value()) { + return std::nullopt; + } + + switch (op) { + case node::FnOperator::kFnOpLt: + return lhs < rhs; + case node::FnOperator::kFnOpLe: + return lhs <= rhs; + case node::FnOperator::kFnOpGt: + return lhs > rhs; + case node::FnOperator::kFnOpGe: + return lhs >= rhs; + case node::FnOperator::kFnOpEq: + return lhs == rhs; + case node::FnOperator::kFnOpNeq: + return lhs != rhs; + default: + break; + } + + return std::nullopt; +} + +template +absl::StatusOr> EvalBinaryExpr(const RowParser* parser, const codec::Row& row, node::FnOperator op, + const node::ExprNode* lhs, const node::ExprNode* rhs) { + absl::Status ret = absl::OkStatus(); + auto ls = ExtractValue(parser, row, lhs); + auto rs = ExtractValue(parser, row, rhs); + ret.Update(ls.status()); + ret.Update(rs.status()); + if (ret.ok()) { + return EvalSimpleBinaryExpr(op, ls.value(), rs.value()); + } + + return ret; +} + +// evaluate the condition expr node +// +// implementation is limited +// * only assumes `cond` is a `BinaryExprNode`, and supports the six basic comparison operators +// * no type inference, the type of ColumnRefNode is used +// +// returns the comparison result +// * true/false/NULL +// * invalid input -> InvalidStatus +absl::StatusOr> EvalCond(const RowParser* parser, const codec::Row& row, + const node::ExprNode* cond); + +// evaluate the condition expr same as `EvalCond` +// but the input `row` and schema are from the pre-agg table. 
+// The expr is also only supported as Binary Expr as 'col < constant', but the col name in the +// pre-agg table is already defined as 'filter_key', rather than taken from the ColumnRefNode child of the binary expr +// +// * type of const node is used for comparison +absl::StatusOr> EvalCondWithAggRow(const RowParser* parser, const codec::Row& row, + const node::ExprNode* cond, absl::string_view filter_col_name); + +// extract compare type for the input binary expr +// +// already assume the input binary expr as style of 'ColumnRefNode op ConstNode' +// and the type of ColumnRefNode is returned +absl::StatusOr ExtractCompareType(const RowParser* parser, const node::BinaryExpr* bin_expr); + +} // namespace internal +} // namespace vm +} // namespace hybridse + +#endif // HYBRIDSE_SRC_VM_INTERNAL_EVAL_H_ diff --git a/hybridse/src/vm/runner.cc b/hybridse/src/vm/runner.cc index fd6191f96c6..1e5c0d5fc85 100644 --- a/hybridse/src/vm/runner.cc +++ b/hybridse/src/vm/runner.cc @@ -21,11 +21,14 @@ #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" #include "base/texttable.h" #include "udf/udf.h" #include "vm/catalog_wrapper.h" #include "vm/core_api.h" +#include "vm/internal/eval.h" #include "vm/jit_runtime.h" #include "vm/mem_catalog.h" @@ -572,10 +575,8 @@ ClusterTask RunnerBuilder::BuildRequestAggUnionTask(PhysicalOpNode* node, Status } auto op = dynamic_cast(node); RequestAggUnionRunner* runner = nullptr; - CreateRunner( - &runner, id_++, node->schemas_ctx(), op->GetLimitCnt(), - op->window().range_, op->exclude_current_time(), - op->output_request_row(), op->func_, op->agg_col_); + CreateRunner(&runner, id_++, node->schemas_ctx(), op->GetLimitCnt(), op->window().range_, + op->exclude_current_time(), op->output_request_row(), op->project_); Key index_key; if (!op->instance_not_in_window()) { index_key = op->window_.index_key(); @@ -2354,9 +2355,7 @@ void Runner::PrintData(std::ostringstream& oss, oss << t; }
-void Runner::PrintRow(std::ostringstream& oss, - const vm::SchemasContext* schema_list, - Row row) { +void Runner::PrintRow(std::ostringstream& oss, const vm::SchemasContext* schema_list, const Row& row) { std::vector row_view_list; ::hybridse::base::TextTable t('-', '|', '+'); // Add Header @@ -2410,6 +2409,12 @@ void Runner::PrintRow(std::ostringstream& oss, oss << t; } +std::string Runner::GetPrettyRow(const vm::SchemasContext* schema_list, const Row& row) { + std::ostringstream os; + PrintRow(os, schema_list, row); + return os.str(); +} + bool Runner::ExtractRows(std::shared_ptr handlers, std::vector& out_rows) { @@ -2690,7 +2695,7 @@ bool RequestAggUnionRunner::InitAggregator() { if (agg_col_->GetExprType() == node::kExprColumnRef) { agg_col_type_ = producers_[1]->row_parser()->GetType(agg_col_name_); } else if (agg_col_->GetExprType() == node::kExprAll) { - if (agg_type_ != kCount) { + if (agg_type_ != kCount && agg_type_ != kCountWhere) { LOG(ERROR) << "only support " << ExprTypeName(agg_col_->GetExprType()) << "on count op"; return false; } @@ -2705,14 +2710,19 @@ bool RequestAggUnionRunner::InitAggregator() { std::unique_ptr RequestAggUnionRunner::CreateAggregator() const { switch (agg_type_) { case kSum: + case kSumWhere: return MakeOverflowAggregator(agg_col_type_, *output_schemas_->GetOutputSchema()); case kAvg: + case kAvgWhere: return std::make_unique(agg_col_type_, *output_schemas_->GetOutputSchema()); case kCount: + case kCountWhere: return std::make_unique(agg_col_type_, *output_schemas_->GetOutputSchema()); case kMin: + case kMinWhere: return MakeSameTypeAggregator(agg_col_type_, *output_schemas_->GetOutputSchema()); case kMax: + case kMaxWhere: return MakeSameTypeAggregator(agg_col_type_, *output_schemas_->GetOutputSchema()); default: LOG(ERROR) << "RequestAggUnionRunner does not support for op " << func_->GetName(); @@ -2747,7 +2757,7 @@ std::shared_ptr RequestAggUnionRunner::Run( for (size_t i = 0; i < union_inputs.size(); i++) { 
std::ostringstream sss; PrintData(sss, producers_[i + 1]->output_schemas(), union_inputs[i]); - LOG(INFO) << "union input " << i << ": " << sss.str(); + LOG(INFO) << "union input " << i << ":\n" << sss.str(); } } @@ -2770,7 +2780,7 @@ std::shared_ptr RequestAggUnionRunner::Run( std::ostringstream sss; PrintData(sss, producers_[i + 1]->output_schemas(), union_segments[i]); - LOG(INFO) << "union output " << i << ": " << sss.str(); + LOG(INFO) << "union output " << i << ":\n" << sss.str(); } } @@ -2786,11 +2796,6 @@ std::shared_ptr RequestAggUnionRunner::Run( exclude_current_time_, !output_request_row_); } - if (ctx.is_debug()) { - std::ostringstream oss; - PrintData(oss, output_schemas(), window); - LOG(INFO) << "Request AGG UNION output: " << oss.str(); - } return window; } @@ -2822,16 +2827,12 @@ std::shared_ptr RequestAggUnionRunner::RequestUnionWindow( int64_t max_size = 0; if (ts_gen >= 0) { if (window_range.frame_type_ != Window::kFrameRows) { - start = (ts_gen + window_range.start_offset_) < 0 - ? 0 - : (ts_gen + window_range.start_offset_); + start = (ts_gen + window_range.start_offset_) < 0 ? 0 : (ts_gen + window_range.start_offset_); } if (exclude_current_time && 0 == window_range.end_offset_) { end = (ts_gen - 1) < 0 ? 0 : (ts_gen - 1); } else { - end = (ts_gen + window_range.end_offset_) < 0 - ? 0 - : (ts_gen + window_range.end_offset_); + end = (ts_gen + window_range.end_offset_) < 0 ? 
0 : (ts_gen + window_range.end_offset_); } rows_start_preceding = window_range.start_row_; max_size = window_range.max_size_; @@ -2840,15 +2841,33 @@ std::shared_ptr RequestAggUnionRunner::RequestUnionWindow( auto aggregator = CreateAggregator(); auto update_base_aggregator = [aggregator = aggregator.get(), row_parser = base_row_parser, this](const Row& row) { + DLOG(INFO) << "[Update Base]\n" << GetPrettyRow(row_parser->schema_ctx(), row); if (!agg_col_name_.empty() && row_parser->IsNull(row, agg_col_name_)) { return; } + if (cond_ != nullptr) { + // for those condition exists and evaluated to NULL/false + // will apply to functions `*_where` + // include `count_where` has supported, or `{min/max/avg/sum}_where` support later + auto matches = internal::EvalCond(row_parser, row, cond_); + DLOG(INFO) << "[Update Base Filter] Evaluate result of " << cond_->GetExprString() << ": " + << PrintEvalValue(matches); + if (!matches.ok()) { + LOG(ERROR) << matches.status(); + return; + } + if (false == matches->value_or(false)) { + return; + } + } + auto type = aggregator->type(); - if (agg_type_ == kCount) { + if (agg_type_ == kCount || agg_type_ == kCountWhere) { dynamic_cast*>(aggregator)->UpdateValue(1); return; } + if (agg_col_name_.empty()) { return; } @@ -2897,20 +2916,33 @@ std::shared_ptr RequestAggUnionRunner::RequestUnionWindow( } }; - auto update_agg_aggregator = [aggregator = aggregator.get(), row_parser = agg_row_parser](const Row& row) { + auto update_agg_aggregator = [aggregator = aggregator.get(), row_parser = agg_row_parser, this](const Row& row) { + DLOG(INFO) << "[Update Agg]\n" << GetPrettyRow(row_parser->schema_ctx(), row); if (row_parser->IsNull(row, "agg_val")) { return; } + if (cond_ != nullptr) { + auto matches = internal::EvalCondWithAggRow(row_parser, row, cond_, "filter_key"); + DLOG(INFO) << "[Update Agg Filter] Evaluate result of " << cond_->GetExprString() << ": " + << PrintEvalValue(matches); + if (!matches.ok()) { + LOG(ERROR) << 
matches.status(); + return; + } + if (false == matches->value_or(false)) { + return; + } + } + std::string agg_val; row_parser->GetString(row, "agg_val", &agg_val); aggregator->Update(agg_val); }; int64_t cnt = 0; - auto range_status = window_range.GetWindowPositionStatus( - cnt > rows_start_preceding, window_range.end_offset_ < 0, - request_key < start); + auto range_status = window_range.GetWindowPositionStatus(cnt > rows_start_preceding, window_range.end_offset_ < 0, + request_key < start); if (output_request_row) { update_base_aggregator(request); } @@ -2930,39 +2962,57 @@ std::shared_ptr RequestAggUnionRunner::RequestUnionWindow( auto agg_it = union_segments[1]->GetIterator(); if (agg_it) { - agg_it->Seek(end); + agg_it->Seek(end); } else { LOG(WARNING) << "Agg window is empty. Use base window only"; } // we'll iterate over the following ranges: - // - base(end_base, end] if end_base < end - // - agg[start_base, end_base] - // - base[start, start_base) if start < start_base - int64_t end_base = start; - int64_t start_base = start + 1; - if (agg_it && agg_it->Valid()) { - int64_t ts_start = agg_it->GetKey(); + // 1. base(end_base, end] if end_base < end + // 2. agg[start_base, end_base] + // 3. base[start, start_base) if start < start_base + // + // | start .. | start_base ... end_base | .. end | + // | <----------------- iterate order (end to start) + // + // when start_base > end_base, step 2 skipped, fallback as + // | start .. | end_base .. 
end | + // | <----------------- iterate order (end to start) + std::optional end_base = start; + std::optional start_base = {}; + if (agg_it) { + int64_t ts_start = -1; int64_t ts_end = -1; - agg_row_parser->GetValue(agg_it->GetValue(), "ts_end", type::Type::kTimestamp, &ts_end); - if (ts_end > end) { // [ts_start, ts_end] covers beyond the [start, end] region - end_base = ts_start; - agg_it->Next(); - if (agg_it->Valid()) { - agg_row_parser->GetValue(agg_it->GetValue(), "ts_end", type::Type::kTimestamp, &ts_end); - end_base = ts_end; - } else { - // only base table will be used - end_base = start; - start_base = start + 1; + // iterate through agg_it and find the first one that + // - agg record inside window frame + // - key (ts_start) >= start + // - ts_end <= end + while (agg_it->Valid()) { + ts_start = agg_it->GetKey(); + agg_row_parser->GetValue(agg_it->GetValue(), "ts_end", type::Type::kTimestamp, &ts_end); + if (ts_end <= end) { + break; } - } else { - end_base = ts_end; + + agg_it->Next(); } + + if (ts_end != -1 && ts_start >= start) { + // first agg record inside window frame + end_base = ts_end; + // assign a value to start_base so agg aggregation happens + start_base = start + 1; + } /* else only base table will be used */ } - // iterate over base table from end (inclusive) to end_base (exclusive) + // NOTE: start_base is not correct until step 2 has finished + DLOG(INFO) << absl::Substitute( + "[RequestUnion]($6) {start=$0, start_base=$1, end_base=$2, end=$3, base_key=$4, agg_key=$5}", start, + start_base.value_or(-1), end_base.value_or(-1), end, base_it->GetKey(), (agg_it ? agg_it->GetKey() : -1), + (cond_ ? cond_->GetExprString() : "")); + + // 1.
iterate over base table from [end, end_base) end (inclusive) to end_base (exclusive) if (end_base < end) { while (base_it->Valid()) { if (max_size > 0 && cnt >= max_size) { @@ -2985,47 +3035,118 @@ std::shared_ptr RequestAggUnionRunner::RequestUnionWindow( } } - // iterate over agg table from end_base until start (both inclusive) - int64_t last_ts_start = INT64_MAX; - while (agg_it && agg_it->Valid()) { + // 2. iterate over agg table from end_base until start_base (both inclusive) + int64_t prev_ts_start = INT64_MAX; + while (start_base.has_value() && start_base <= end_base && agg_it != nullptr && agg_it->Valid()) { if (max_size > 0 && cnt >= max_size) { break; } - int64_t ts_start = agg_it->GetKey(); - // for mem-table, updating will inserts duplicate entries - if (last_ts_start == ts_start) { - DLOG(INFO) << "Found duplicate entries in agg table for ts_start = " << ts_start; - continue; - } - last_ts_start = ts_start; + if (cond_ == nullptr) { + int64_t ts_start = agg_it->GetKey(); + const Row& row = agg_it->GetValue(); + if (prev_ts_start == ts_start) { + DLOG(INFO) << "Found duplicate entries in agg table for ts_start = " << ts_start; + agg_it->Next(); + continue; + } + prev_ts_start = ts_start; + + int64_t ts_end = -1; + agg_row_parser->GetValue(row, "ts_end", type::Type::kTimestamp, &ts_end); + int num_rows = 0; + agg_row_parser->GetValue(row, "num_rows", type::Type::kInt32, &num_rows); + + // FIXME(zhanghao): check cnt and rows_start_preceding meanings + int next_incr = num_rows > 0 ? 
num_rows - 1 : 0; + auto range_status = window_range.GetWindowPositionStatus(cnt + next_incr > rows_start_preceding, + ts_start > end, ts_start < start); + if ((max_size > 0 && cnt + next_incr >= max_size) || WindowRange::kExceedWindow == range_status) { + start_base = ts_end + 1; + break; + } + if (WindowRange::kInWindow == range_status) { + update_agg_aggregator(row); + cnt += num_rows; + } - const Row& row = agg_it->GetValue(); - int64_t ts_end = -1; - agg_row_parser->GetValue(row, "ts_end", type::Type::kTimestamp, &ts_end); - int num_rows = 0; - agg_row_parser->GetValue(row, "num_rows", type::Type::kInt32, &num_rows); - - // FIXME(zhanghao): check cnt and rows_start_preceding meanings - int next_incr = num_rows > 0 ? num_rows - 1 : 0; - auto range_status = window_range.GetWindowPositionStatus(cnt + next_incr > rows_start_preceding, ts_start > end, - ts_start < start); - if ((max_size > 0 && cnt + next_incr >= max_size) || WindowRange::kExceedWindow == range_status) { - start_base = ts_end + 1; - break; - } - if (WindowRange::kInWindow == range_status) { - update_agg_aggregator(row); - cnt += num_rows; - } + start_base = ts_start; + agg_it->Next(); + } else { + const int64_t ts_start = agg_it->GetKey(); + + // for agg rows has filter_key + // max_size check should happen after iterate all agg rows for the same key + std::vector key_agg_rows; + std::set filter_val_set; + + int total_rows = 0; + int64_t ts_end_range = -1; + agg_row_parser->GetValue(agg_it->GetValue(), "ts_end", type::Type::kTimestamp, &ts_end_range); + while (agg_it->Valid() && ts_start == agg_it->GetKey()) { + const Row& drow = agg_it->GetValue(); + + std::string filter_val; + if (agg_row_parser->IsNull(drow, "filter_key")) { + LOG(ERROR) << "filter_key is null for *_where op"; + agg_it->Next(); + continue; + } + if (0 != agg_row_parser->GetString(drow, "filter_key", &filter_val)) { + LOG(ERROR) << "failed to get value of filter_key"; + agg_it->Next(); + continue; + } + + if (prev_ts_start == 
ts_start && filter_val_set.count(filter_val) != 0) { + DLOG(INFO) << "Found duplicate entries in agg table for ts_start = " << ts_start + << ", filter_key=" << filter_val; + agg_it->Next(); + continue; + } + + prev_ts_start = ts_start; + filter_val_set.insert(filter_val); - start_base = ts_start; - agg_it->Next(); + int num_rows = 0; + agg_row_parser->GetValue(drow, "num_rows", type::Type::kInt32, &num_rows); + + if (num_rows > 0) { + total_rows += num_rows; + key_agg_rows.push_back(drow); + } + + agg_it->Next(); + } + + int next_incr = total_rows > 0 ? total_rows - 1 : 0; + auto range_status = window_range.GetWindowPositionStatus(cnt + next_incr > rows_start_preceding, + ts_start > end, ts_start < start); + if ((max_size > 0 && cnt + next_incr >= max_size) || WindowRange::kExceedWindow == range_status) { + start_base = ts_end_range + 1; + break; + } + if (WindowRange::kInWindow == range_status) { + for (auto& row : key_agg_rows) { + update_agg_aggregator(row); + } + cnt += total_rows; + } + + start_base = ts_start; + } } - if (start_base > 0) { - // iterate over base table from start_base (exclusive) to start (inclusive) - base_it->Seek(start_base - 1); + // 3. 
iterate over base table from start_base (exclusive) to start (inclusive) + // + // if start_base is empty -> + // step 2 skipped, this step only aggregates on key = start + // otherwise -> + // if start_base is 0 -> skipped + // otherwise -> agg over [start, start_base) + int64_t step_3_start = start_base.value_or(start + 1); + if (step_3_start > 0) { + base_it->Seek(step_3_start - 1); while (base_it->Valid()) { int64_t ts = base_it->GetKey(); auto range_status = window_range.GetWindowPositionStatus(static_cast(cnt) > rows_start_preceding, @@ -3047,6 +3168,16 @@ std::shared_ptr RequestAggUnionRunner::RequestUnionWindow( return window_table; } +std::string RequestAggUnionRunner::PrintEvalValue(const absl::StatusOr>& val) { + std::ostringstream os; + if (!val.ok()) { + os << val.status(); + } else { + os << (val->has_value() ? (val->value() ? "TRUE" : "FALSE") : "NULL"); + } + return os.str(); +} + std::shared_ptr ReduceRunner::Run( RunnerContext& ctx, const std::vector>& inputs) { @@ -3064,11 +3195,6 @@ std::shared_ptr ReduceRunner::Run( return std::shared_ptr(); } auto table = std::dynamic_pointer_cast(input); - if (ctx.is_debug()) { - std::ostringstream oss; - PrintData(oss, producers_[0]->output_schemas(), table); - LOG(WARNING) << "ReduceRunner input: " << oss.str(); - } auto parameter = ctx.GetParameterRow(); if (having_condition_.Valid() && !having_condition_.Gen(table, parameter)) { @@ -3083,11 +3209,6 @@ std::shared_ptr ReduceRunner::Run( } std::shared_ptr row_handler = std::make_shared(iter->GetValue()); - if (ctx.is_debug()) { - std::ostringstream oss; - PrintData(oss, producers_[0]->output_schemas(), row_handler); - LOG(WARNING) << "ReduceRunner output: " << oss.str(); - } return row_handler; } diff --git a/hybridse/src/vm/runner.h b/hybridse/src/vm/runner.h index 0d1ec851b66..301b95b5f18 100644 --- a/hybridse/src/vm/runner.h +++ b/hybridse/src/vm/runner.h @@ -24,6 +24,9 @@ #include #include #include + +#include "absl/container/flat_hash_map.h" +#include
"absl/status/statusor.h" #include "base/fe_status.h" #include "codec/fe_row_codec.h" #include "node/node_manager.h" @@ -518,9 +521,9 @@ class Runner : public node::NodeBase { static void PrintData(std::ostringstream& oss, const vm::SchemasContext* schema_list, std::shared_ptr data); - static void PrintRow(std::ostringstream& oss, - const vm::SchemasContext* schema_list, - Row row); + static void PrintRow(std::ostringstream& oss, const vm::SchemasContext* schema_list, const Row& row); + static std::string GetPrettyRow(const vm::SchemasContext* schema_list, const Row& row); + static const bool IsProxyRunner(const RunnerType& type) { return kRunnerRequestRunProxy == type || kRunnerBatchRequestRunProxy == type; @@ -990,18 +993,23 @@ class RequestUnionRunner : public Runner { class RequestAggUnionRunner : public Runner { public: RequestAggUnionRunner(const int32_t id, const SchemasContext* schema, const int32_t limit_cnt, const Range& range, - bool exclude_current_time, bool output_request_row, const node::FnDefNode* func, - const node::ExprNode* agg_col) + bool exclude_current_time, bool output_request_row, const node::CallExprNode* project) : Runner(id, kRunnerRequestAggUnion, schema, limit_cnt), range_gen_(range), exclude_current_time_(exclude_current_time), output_request_row_(output_request_row), - func_(func), - agg_col_(agg_col) { - if (agg_col_->GetExprType() == node::kExprColumnRef) { - agg_col_name_ = dynamic_cast(agg_col_)->GetColumnName(); + func_(project->GetFnDef()), + agg_col_(project->GetChild(0)) { + if (agg_col_->GetExprType() == node::kExprColumnRef) { + agg_col_name_ = dynamic_cast(agg_col_)->GetColumnName(); + } /* for kAllExpr like count(*), agg_col_name_ is empty */ + + if (project->GetChildNum() >= 2) { + // assume second kid of project as filter condition + // function support check happens in compile + cond_ = project->GetChild(1); + } } -} bool InitAggregator(); std::shared_ptr Run(RunnerContext& ctx, @@ -1015,13 +1023,20 @@ class 
RequestAggUnionRunner : public Runner { windows_union_gen_.AddWindowUnion(window, runner); } + static std::string PrintEvalValue(const absl::StatusOr>& val); + private: enum AggType { kSum, kCount, kAvg, kMin, - kMax + kMax, + kCountWhere, + kSumWhere, + kAvgWhere, + kMinWhere, + kMaxWhere, }; RequestWindowUnionGenerator windows_union_gen_; @@ -1038,10 +1053,23 @@ class RequestAggUnionRunner : public Runner { std::string agg_col_name_; type::Type agg_col_type_; + // the filter condition for count_where + // simple comparison binary expr like col < 0 is supported + node::ExprNode* cond_ = nullptr; + std::unique_ptr CreateAggregator() const; - static inline const std::unordered_map agg_type_map_ = { - {"sum", kSum}, {"count", kCount}, {"avg", kAvg}, {"min", kMin}, {"max", kMax}, - }; + + static inline const absl::flat_hash_map agg_type_map_ = { + {"sum", kSum}, + {"count", kCount}, + {"avg", kAvg}, + {"min", kMin}, + {"max", kMax}, + {"count_where", kCountWhere}, + {"sum_where", kSumWhere}, + {"avg_where", kAvgWhere}, + {"min_where", kMinWhere}, + {"max_where", kMaxWhere}}; }; class PostRequestUnionRunner : public Runner { diff --git a/hybridse/src/vm/schemas_context.cc b/hybridse/src/vm/schemas_context.cc index 469aad85ade..b70adb9df43 100644 --- a/hybridse/src/vm/schemas_context.cc +++ b/hybridse/src/vm/schemas_context.cc @@ -787,7 +787,10 @@ int32_t RowParser::GetString(const Row& row, const std::string& col, std::string const codec::RowView& row_view = row_view_list_[schema_idx]; const char* ch = nullptr; uint32_t str_size; - row_view.GetValue(row.buf(schema_idx), col_idx, &ch, &str_size); + int ret = row_view.GetValue(row.buf(schema_idx), col_idx, &ch, &str_size); + if (0 != ret) { + return ret; + } std::string tmp(ch, str_size); val->swap(tmp); diff --git a/hybridse/src/vm/simple_catalog.cc b/hybridse/src/vm/simple_catalog.cc index 929a08e1776..76093858ace 100644 --- a/hybridse/src/vm/simple_catalog.cc +++ b/hybridse/src/vm/simple_catalog.cc @@ -217,15 +217,12
@@ bool SimpleCatalogTableHandler::AddRow(const Row row) { return true; } -std::vector SimpleCatalog::GetAggrTables( - const std::string& base_db, - const std::string& base_table, - const std::string& aggr_func, - const std::string& aggr_col, - const std::string& partition_cols, - const std::string& order_col) { - ::hybridse::vm::AggrTableInfo info = {"aggr_" + base_table, "aggr_db", base_db, base_table, - aggr_func, aggr_col, partition_cols, order_col, "1000"}; +std::vector SimpleCatalog::GetAggrTables(const std::string &base_db, const std::string &base_table, + const std::string &aggr_func, const std::string &aggr_col, + const std::string &partition_cols, const std::string &order_col, + const std::string &filter_col) { + ::hybridse::vm::AggrTableInfo info = {"aggr_" + base_table, "aggr_db", base_db, base_table, aggr_func, aggr_col, + partition_cols, order_col, "1000", filter_col}; return {info}; } diff --git a/hybridse/src/vm/transform_request_mode_test.cc b/hybridse/src/vm/transform_request_mode_test.cc index 3683335def0..cde33bc289c 100644 --- a/hybridse/src/vm/transform_request_mode_test.cc +++ b/hybridse/src/vm/transform_request_mode_test.cc @@ -563,46 +563,63 @@ TEST_F(TransformRequestModePassOptimizedTest, SplitAggregationOptimizedTest) { } TEST_F(TransformRequestModePassOptimizedTest, LongWindowOptimizedTest) { + // five long window agg applied const std::string sql = - "SELECT col1, sum(col2) OVER w1, col2+1, add(col2, col1), count(col2) OVER w1, " - "sum(col2) over w2 as w1_col2_sum , sum(col2) over w3 FROM t1\n" - "WINDOW w1 AS (PARTITION BY col1 ORDER BY col5 ROWS_RANGE BETWEEN 3m PRECEDING AND CURRENT ROW)," - "w2 AS (PARTITION BY col1,col2 ORDER BY col5 ROWS_RANGE BETWEEN 3 PRECEDING AND CURRENT ROW)," - "w3 AS (PARTITION BY col1 ORDER BY col5 ROWS_RANGE BETWEEN 3 PRECEDING AND CURRENT ROW);"; + R"(SELECT + col1, + sum(col2) OVER w1, + col2+1, + add(col2, col1), + count(col2) OVER w1, + sum(col2) over w2 as w1_col2_sum , + sum(col2) over w3, + 
count_where(col0, col1 > 1) over w1 as cw1, + count_where(*, col5 = 0) over w1 as cw2, + FROM t1 + WINDOW w1 AS (PARTITION BY col1 ORDER BY col5 ROWS_RANGE BETWEEN 3m PRECEDING AND CURRENT ROW), + w2 AS (PARTITION BY col1,col2 ORDER BY col5 ROWS_RANGE BETWEEN 3 PRECEDING AND CURRENT ROW), + w3 AS (PARTITION BY col1 ORDER BY col5 ROWS_RANGE BETWEEN 3 PRECEDING AND CURRENT ROW);)"; const std::string expected = - "SIMPLE_PROJECT(sources=(col1, sum(col2)over w1, col2 + 1, add(col2, col1), count(col2)over w1, w1_col2_sum, " - "sum(col2)over w3))\n" - " REQUEST_JOIN(type=kJoinTypeConcat)\n" - " REQUEST_JOIN(type=kJoinTypeConcat)\n" - " REQUEST_JOIN(type=kJoinTypeConcat)\n" - " PROJECT(type=RowProject)\n" - " DATA_PROVIDER(request=t1)\n" - " SIMPLE_PROJECT(sources=(sum(col2)over w1, count(col2)over w1))\n" - " REQUEST_JOIN(type=kJoinTypeConcat)\n" - " PROJECT(type=ReduceAggregation: sum(col2)over w1 (range[180000 PRECEDING,0 CURRENT]))\n" - " REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 180000 PRECEDING, 0 CURRENT), " - "index_keys=(col1))\n" - " DATA_PROVIDER(request=t1)\n" - " DATA_PROVIDER(type=Partition, table=t1, index=index1)\n" - " DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2)\n" - " PROJECT(type=ReduceAggregation: count(col2)over w1 (range[180000 PRECEDING,0 CURRENT]))\n" - " REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 180000 PRECEDING, 0 CURRENT), " - "index_keys=(col1))\n" - " DATA_PROVIDER(request=t1)\n" - " DATA_PROVIDER(type=Partition, table=t1, index=index1)\n" - " DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2)\n" - " PROJECT(type=ReduceAggregation: sum(col2)over w2 (range[3 PRECEDING,0 CURRENT]))\n" - " REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 3 PRECEDING, 0 CURRENT), " - "index_keys=(col1,col2))\n" - " DATA_PROVIDER(request=t1)\n" - " DATA_PROVIDER(type=Partition, table=t1, index=index12)\n" - " DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2)\n" - " 
PROJECT(type=Aggregation)\n" - " REQUEST_UNION(partition_keys=(), orders=(ASC), range=(col5, 3 PRECEDING, 0 CURRENT), " - "index_keys=(col1))\n" - " DATA_PROVIDER(request=t1)\n" - " DATA_PROVIDER(type=Partition, table=t1, index=index1)"; + R"(SIMPLE_PROJECT(sources=(col1, sum(col2)over w1, col2 + 1, add(col2, col1), count(col2)over w1, w1_col2_sum, sum(col2)over w3, cw1, cw2)) + REQUEST_JOIN(type=kJoinTypeConcat) + REQUEST_JOIN(type=kJoinTypeConcat) + REQUEST_JOIN(type=kJoinTypeConcat) + PROJECT(type=RowProject) + DATA_PROVIDER(request=t1) + SIMPLE_PROJECT(sources=(sum(col2)over w1, count(col2)over w1, cw1, cw2)) + REQUEST_JOIN(type=kJoinTypeConcat) + REQUEST_JOIN(type=kJoinTypeConcat) + REQUEST_JOIN(type=kJoinTypeConcat) + PROJECT(type=ReduceAggregation: sum(col2)over w1 (range[180000 PRECEDING,0 CURRENT])) + REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 180000 PRECEDING, 0 CURRENT), index_keys=(col1)) + DATA_PROVIDER(request=t1) + DATA_PROVIDER(type=Partition, table=t1, index=index1) + DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2) + PROJECT(type=ReduceAggregation: count(col2)over w1 (range[180000 PRECEDING,0 CURRENT])) + REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 180000 PRECEDING, 0 CURRENT), index_keys=(col1)) + DATA_PROVIDER(request=t1) + DATA_PROVIDER(type=Partition, table=t1, index=index1) + DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2) + PROJECT(type=ReduceAggregation: count_where(col0, col1 > 1)over w1 (range[180000 PRECEDING,0 CURRENT])) + REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 180000 PRECEDING, 0 CURRENT), index_keys=(col1)) + DATA_PROVIDER(request=t1) + DATA_PROVIDER(type=Partition, table=t1, index=index1) + DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2) + PROJECT(type=ReduceAggregation: count_where(*, col5 = 0)over w1 (range[180000 PRECEDING,0 CURRENT])) + REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 180000 PRECEDING, 0 
CURRENT), index_keys=(col1)) + DATA_PROVIDER(request=t1) + DATA_PROVIDER(type=Partition, table=t1, index=index1) + DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2) + PROJECT(type=ReduceAggregation: sum(col2)over w2 (range[3 PRECEDING,0 CURRENT])) + REQUEST_AGG_UNION(partition_keys=(), orders=(ASC), range=(col5, 3 PRECEDING, 0 CURRENT), index_keys=(col1,col2)) + DATA_PROVIDER(request=t1) + DATA_PROVIDER(type=Partition, table=t1, index=index12) + DATA_PROVIDER(type=Partition, table=aggr_t1, index=index1_t2) + PROJECT(type=Aggregation) + REQUEST_UNION(partition_keys=(), orders=(ASC), range=(col5, 3 PRECEDING, 0 CURRENT), index_keys=(col1)) + DATA_PROVIDER(request=t1) + DATA_PROVIDER(type=Partition, table=t1, index=index1))"; std::shared_ptr catalog(new SimpleCatalog(true)); hybridse::type::TableDef table_def; diff --git a/src/base/ddl_parser.cc b/src/base/ddl_parser.cc index 7afcc721a88..2b135467edb 100644 --- a/src/base/ddl_parser.cc +++ b/src/base/ddl_parser.cc @@ -22,6 +22,7 @@ #include #include +#include "absl/strings/match.h" #include "codec/schema_codec.h" #include "common/timer.h" #include "node/node_manager.h" @@ -219,11 +220,13 @@ hybridse::sdk::Status DDLParser::ExtractLongWindowInfos(const std::string& sql, if (0 != sql_status.code) { DLOG(ERROR) << sql_status.msg; - return hybridse::sdk::Status(base::ReturnCode::kError, sql_status.msg); + return hybridse::sdk::Status(base::ReturnCode::kError, sql_status.msg, sql_status.GetTraces()); } + hybridse::node::PlanNode* node = plan_trees[0]; switch (node->GetType()) { case hybridse::node::kPlanTypeQuery: { + // TODO(ace): Traverse Node return Status if (!TraverseNode(node, window_map, infos)) { return hybridse::sdk::Status(base::ReturnCode::kError, "TraverseNode failed"); } diff --git a/src/base/ddl_parser_test.cc b/src/base/ddl_parser_test.cc index 0d47aaefd32..20188665096 100644 --- a/src/base/ddl_parser_test.cc +++ b/src/base/ddl_parser_test.cc @@ -660,7 +660,7 @@ TEST_F(DDLParserTest, 
extractLongWindow) { "ROWS BETWEEN 2 PRECEDING AND CURRENT ROW);"; std::unordered_map window_map; - window_map["w1"] = "1000"; + window_map["w1"] = "1s"; openmldb::base::LongWindowInfos window_infos; auto extract_status = DDLParser::ExtractLongWindowInfos(query, window_map, &window_infos); ASSERT_TRUE(extract_status.IsOK()); @@ -670,7 +670,7 @@ TEST_F(DDLParserTest, extractLongWindow) { ASSERT_EQ(window_infos[0].aggr_col_, "c3"); ASSERT_EQ(window_infos[0].partition_col_, "c1"); ASSERT_EQ(window_infos[0].order_col_, "c6"); - ASSERT_EQ(window_infos[0].bucket_size_, "1000"); + ASSERT_EQ(window_infos[0].bucket_size_, "1s"); ASSERT_EQ(window_infos[0].filter_col_, "c1"); } diff --git a/src/catalog/tablet_catalog.cc b/src/catalog/tablet_catalog.cc index 8b8a4686606..1117bf97971 100644 --- a/src/catalog/tablet_catalog.cc +++ b/src/catalog/tablet_catalog.cc @@ -532,13 +532,10 @@ const Procedures& TabletCatalog::GetProcedures() { } std::vector<::hybridse::vm::AggrTableInfo> TabletCatalog::GetAggrTables( - const std::string& base_db, - const std::string& base_table, - const std::string& aggr_func, - const std::string& aggr_col, - const std::string& partition_cols, - const std::string& order_col) { - AggrTableKey key{base_db, base_table, aggr_func, aggr_col, partition_cols, order_col}; + const std::string& base_db, const std::string& base_table, const std::string& aggr_func, + const std::string& aggr_col, const std::string& partition_cols, const std::string& order_col, + const std::string& filter_col) { + AggrTableKey key{base_db, base_table, aggr_func, aggr_col, partition_cols, order_col, filter_col}; auto aggr_tables = std::atomic_load_explicit(&aggr_tables_, std::memory_order_acquire); return (*aggr_tables)[key]; } @@ -547,12 +544,12 @@ void TabletCatalog::RefreshAggrTables(const std::vector<::hybridse::vm::AggrTabl auto new_aggr_tables = std::make_shared(); for (const auto& table_info : table_infos) { // TODO(zhanghao): can use AggrTableKey *table_key = 
static_cast(&table_info); - AggrTableKey table_key{table_info.base_db, table_info.base_table, - table_info.aggr_func, table_info.aggr_col, - table_info.partition_cols, table_info.order_by_col}; + AggrTableKey table_key{table_info.base_db, table_info.base_table, table_info.aggr_func, + table_info.aggr_col, table_info.partition_cols, table_info.order_by_col, + table_info.filter_col}; if (new_aggr_tables->count(table_key) == 0) { new_aggr_tables->emplace(std::move(table_key), - std::vector<::hybridse::vm::AggrTableInfo>{std::move(table_info)}); + std::vector<::hybridse::vm::AggrTableInfo>{std::move(table_info)}); } else { new_aggr_tables->at(table_key).push_back(std::move(table_info)); } diff --git a/src/catalog/tablet_catalog.h b/src/catalog/tablet_catalog.h index eca85ac644f..c032921c582 100644 --- a/src/catalog/tablet_catalog.h +++ b/src/catalog/tablet_catalog.h @@ -255,13 +255,11 @@ class TabletCatalog : public ::hybridse::vm::Catalog { const Procedures &GetProcedures(); - std::vector<::hybridse::vm::AggrTableInfo> GetAggrTables( - const std::string& base_db, - const std::string& base_table, - const std::string& aggr_func, - const std::string& aggr_col, - const std::string& partition_cols, - const std::string& order_col) override; + std::vector<::hybridse::vm::AggrTableInfo> GetAggrTables(const std::string &base_db, const std::string &base_table, + const std::string &aggr_func, const std::string &aggr_col, + const std::string &partition_cols, + const std::string &order_col, + const std::string &filter_col) override; void RefreshAggrTables(const std::vector<::hybridse::vm::AggrTableInfo>& entries); @@ -273,12 +271,13 @@ class TabletCatalog : public ::hybridse::vm::Catalog { std::string aggr_col; std::string partition_cols; std::string order_by_col; + std::string filter_col; }; struct AggrTableKeyHash { std::size_t operator()(const AggrTableKey& key) const { - return std::hash()(key.base_db + key.base_table + key.aggr_func + - key.aggr_col + key.partition_cols + 
key.order_by_col); + return std::hash()(key.base_db + key.base_table + key.aggr_func + key.aggr_col + + key.partition_cols + key.order_by_col + key.filter_col); } }; @@ -289,7 +288,8 @@ class TabletCatalog : public ::hybridse::vm::Catalog { lhs.aggr_func == rhs.aggr_func && lhs.aggr_col == rhs.aggr_col && lhs.partition_cols == rhs.partition_cols && - lhs.order_by_col == rhs.order_by_col; + lhs.order_by_col == rhs.order_by_col && + lhs.filter_col == rhs.filter_col; } }; diff --git a/src/catalog/tablet_catalog_test.cc b/src/catalog/tablet_catalog_test.cc index 7a6f17b6d14..1f134028c60 100644 --- a/src/catalog/tablet_catalog_test.cc +++ b/src/catalog/tablet_catalog_test.cc @@ -704,20 +704,20 @@ TEST_F(TabletCatalogTest, aggr_table_test) { infos.push_back(info3); catalog->RefreshAggrTables(infos); - auto res = catalog->GetAggrTables("base_db", "base_t1", "sum", "col1", "col2", "col3"); + auto res = catalog->GetAggrTables("base_db", "base_t1", "sum", "col1", "col2", "col3", ""); ASSERT_EQ(2, res.size()); ASSERT_EQ(info1, res[0]); ASSERT_EQ(info2, res[1]); - res = catalog->GetAggrTables("base_db", "base_t1", "avg", "col1", "col2,col4", "col3"); + res = catalog->GetAggrTables("base_db", "base_t1", "avg", "col1", "col2,col4", "col3", ""); ASSERT_EQ(1, res.size()); ASSERT_EQ(info3, res[0]); - res = catalog->GetAggrTables("base_db", "base_t1", "count", "col1", "col2,col4", "col3"); + res = catalog->GetAggrTables("base_db", "base_t1", "count", "col1", "col2,col4", "col3", ""); ASSERT_EQ(0, res.size()); } -TEST_F(TabletCatalogTest, long_window_smoke_test) { +TEST_F(TabletCatalogTest, LongWindowSmokeTest) { std::shared_ptr catalog(new TabletCatalog()); ASSERT_TRUE(catalog->Init()); int num_pk = 2, num_ts = 9, bucket_size = 2; @@ -728,8 +728,8 @@ TEST_F(TabletCatalogTest, long_window_smoke_test) { TestArgs args2 = PrepareAggTable("aggr_t1", num_pk, num_ts, bucket_size, 1); ASSERT_TRUE(catalog->AddTable(args2.meta[0], args2.tables[0])); - ::hybridse::vm::AggrTableInfo info1 = 
{"aggr_t1", "aggr_db", "db1", "t1", - "sum", "col2", "col1", "col2", "2"}; + ::hybridse::vm::AggrTableInfo info1 = {"aggr_t1", "aggr_db", "db1", "t1", "sum", "col2", "col1", "col2", "2", ""}; + catalog->RefreshAggrTables({info1}); ::hybridse::vm::Engine engine(catalog); diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index 8bcb680dd5f..4912315de9b 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -24,9 +24,11 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/random/random.h" #include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "gflags/gflags.h" @@ -586,7 +588,7 @@ TEST_P(DBSDKTest, DeployOptions) { " WINDOW w1 AS (PARTITION BY trans.c1 ORDER BY trans.c7 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW);"; hybridse::sdk::Status status; sr->ExecuteSQL(deploy_sql, &status); - ASSERT_TRUE(status.IsOK()); + ASSERT_TRUE(status.IsOK()) << status.msg; std::string msg; auto ok = sr->ExecuteDDL(openmldb::nameserver::PRE_AGG_DB, "drop table pre_test2_demo_w1_sum_c4;", &status); ASSERT_TRUE(ok); @@ -698,6 +700,12 @@ void CreateDBTableForLongWindow(const std::string& base_db, const std::string& b ASSERT_EQ(tables.size(), 1) << msg; } +// ----------------------------------------------------------------------------------- +// col1 col2 col3 i64_col i16_col i32_col f_col d_col t_col s_col date_col filter +// str1 str2 i i i i i i i i 1900-01-i i % 2 +// +// where i in [1 .. 
11] +// ----------------------------------------------------------------------------------- void PrepareDataForLongWindow(const std::string& base_db, const std::string& base_table) { ::hybridse::sdk::Status status; for (int i = 1; i <= 11; i++) { @@ -734,9 +742,124 @@ void PrepareRequestRowForLongWindow(const std::string& base_db, const std::strin ASSERT_TRUE(req->AppendTimestamp(11)); ASSERT_TRUE(req->AppendString("11")); ASSERT_TRUE(req->AppendDate(11)); + // filter = null + req->AppendNULL(); ASSERT_TRUE(req->Build()); } +// TODO(ace): create instance of DeployLongWindowEnv with template +class DeployLongWindowEnv { + public: + explicit DeployLongWindowEnv(sdk::SQLClusterRouter* sr) : sr_(sr) {} + + virtual ~DeployLongWindowEnv() {} + + void SetUp() { + db_ = absl::StrCat("db_", absl::Uniform(gen_, 0, std::numeric_limits::max())); + table_ = absl::StrCat("tb_", absl::Uniform(gen_, 0, std::numeric_limits::max())); + dp_ = absl::StrCat("dp_", absl::Uniform(gen_, 0, std::numeric_limits::max())); + + PrepareSchema(); + + ASSERT_TRUE(sr_->RefreshCatalog()); + + Deploy(); + + PrepareData(); + } + + void TearDown() { + TearDownPreAggTables(); + ProcessSQLs(sr_, { + absl::StrCat("drop table ", table_), + absl::StrCat("drop database ", db_), + }); + } + + void CallDeploy(std::shared_ptr* rs) { + hybridse::sdk::Status status; + std::shared_ptr rr = std::make_shared(); + GetRequestRow(&rr, dp_); + auto res = sr_->CallProcedure(db_, dp_, rr, &status); + ASSERT_TRUE(status.IsOK()) << status.msg << "\n" << status.trace; + *rs = std::move(res); + } + + private: + virtual void PrepareSchema() { + ProcessSQLs( + sr_, {"SET @@execute_mode='online';", + absl::StrCat("create database ", db_), + absl::StrCat("use ", db_), + absl::StrCat( + "create table ", table_, + "(col1 string, col2 string, col3 timestamp, i64_col bigint, i16_col smallint, i32_col int, f_col " + "float, d_col double, t_col timestamp, s_col string, date_col date, filter int, " + "index(key=(col1,col2), ts=col3, 
abs_ttl=0, ttl_type=absolute)) " + "options(partitionnum=8);") + }); + } + + virtual void PrepareData() { + // prepare data + // ----------------------------------------------------------------------------------- + // col1 col2 col3 i64_col i16_col i32_col f_col d_col t_col s_col date_col filter + // str1 str2 i * 1000 i i i i i i i 1900-01-i i % 2 + // + // where i in [1 .. 11] + // ----------------------------------------------------------------------------------- + for (int i = 1; i <= 11; i++) { + std::string val = std::to_string(i); + std::string filter_val = std::to_string(i % 2); + std::string date; + if (i < 10) { + date = absl::StrCat("1900-01-0", std::to_string(i)); + } else { + date = absl::StrCat("1900-01-", std::to_string(i)); + } + std::string insert = + absl::StrCat("insert into ", table_, " values('str1', 'str2', ", i * 1000, ", ", val, ", ", val, ", ", + val, ", ", val, ", ", val, ", ", val, ", '", val, "', '", date, "', ", filter_val, ");"); + ::hybridse::sdk::Status s; + bool ok = sr_->ExecuteInsert(db_, insert, &s); + ASSERT_TRUE(ok && s.IsOK()) << s.msg << "\n" << s.trace; + } + } + + virtual void Deploy() = 0; + + virtual void TearDownPreAggTables() = 0; + + void GetRequestRow(std::shared_ptr* rs, const std::string& name) { // NOLINT + ::hybridse::sdk::Status status; + auto req = sr_->GetRequestRowByProcedure(db_, dp_, &status); + ASSERT_TRUE(status.IsOK()); + ASSERT_TRUE(req->Init(strlen("str1") + strlen("str2") + strlen("11"))); + ASSERT_TRUE(req->AppendString("str1")); + ASSERT_TRUE(req->AppendString("str2")); + ASSERT_TRUE(req->AppendTimestamp(11000)); + ASSERT_TRUE(req->AppendInt64(11)); + ASSERT_TRUE(req->AppendInt16(11)); + ASSERT_TRUE(req->AppendInt32(11)); + ASSERT_TRUE(req->AppendFloat(11)); + ASSERT_TRUE(req->AppendDouble(11)); + ASSERT_TRUE(req->AppendTimestamp(11)); + ASSERT_TRUE(req->AppendString("11")); + ASSERT_TRUE(req->AppendDate(11)); + // filter = null + req->AppendNULL(); + ASSERT_TRUE(req->Build()); + *rs = 
std::move(req); + } + + protected: + sdk::SQLClusterRouter* sr_; + absl::BitGen gen_; + std::string db_; + std::string table_; + std::string dp_; +}; + TEST_P(DBSDKTest, DeployLongWindowsEmpty) { auto cli = GetParam(); cs = cli->cs; @@ -1504,22 +1627,22 @@ TEST_P(DBSDKTest, DeployLongWindowsExecuteCount) { LOG(WARNING) << "Before CallProcedure"; auto res = sr->CallProcedure(base_db, "test_aggr", req, &status); LOG(WARNING) << "After CallProcedure"; - ASSERT_TRUE(status.IsOK()); - ASSERT_EQ(1, res->Size()); - ASSERT_TRUE(res->Next()); - ASSERT_EQ("str1", res->GetStringUnsafe(0)); - ASSERT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_TRUE(status.IsOK()); + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); int64_t exp = 7; - ASSERT_EQ(exp, res->GetInt64Unsafe(2)); - ASSERT_EQ(exp, res->GetInt64Unsafe(3)); - ASSERT_EQ(exp, res->GetInt64Unsafe(4)); - ASSERT_EQ(exp, res->GetInt64Unsafe(5)); - ASSERT_EQ(exp, res->GetInt64Unsafe(6)); - ASSERT_EQ(exp, res->GetInt64Unsafe(7)); - ASSERT_EQ(exp, res->GetInt64Unsafe(8)); - ASSERT_EQ(exp, res->GetInt64Unsafe(9)); - ASSERT_EQ(exp, res->GetInt64Unsafe(10)); - ASSERT_EQ(exp, res->GetInt64Unsafe(11)); + EXPECT_EQ(exp, res->GetInt64Unsafe(2)); + EXPECT_EQ(exp, res->GetInt64Unsafe(3)); + EXPECT_EQ(exp, res->GetInt64Unsafe(4)); + EXPECT_EQ(exp, res->GetInt64Unsafe(5)); + EXPECT_EQ(exp, res->GetInt64Unsafe(6)); + EXPECT_EQ(exp, res->GetInt64Unsafe(7)); + EXPECT_EQ(exp, res->GetInt64Unsafe(8)); + EXPECT_EQ(exp, res->GetInt64Unsafe(9)); + EXPECT_EQ(exp, res->GetInt64Unsafe(10)); + EXPECT_EQ(exp, res->GetInt64Unsafe(11)); } ASSERT_TRUE(cs->GetNsClient()->DropProcedure(base_db, "test_aggr", msg)); @@ -1557,6 +1680,8 @@ TEST_P(DBSDKTest, DeployLongWindowsExecuteCount) { } TEST_P(DBSDKTest, DeployLongWindowsExecuteCountWhere) { + GTEST_SKIP() << "count_where for rows window un-supported due to pre-agg rows not aligned"; + auto cli = GetParam(); 
cs = cli->cs; sr = cli->sr; @@ -1568,22 +1693,26 @@ TEST_P(DBSDKTest, DeployLongWindowsExecuteCountWhere) { std::string msg; CreateDBTableForLongWindow(base_db, base_table); - std::string deploy_sql = "deploy test_aggr options(long_windows='w1:2') select col1, col2," - " count_where(i64_col, filter<1) over w1 as w1_count_where_i64_col_filter," - " count_where(i64_col, col1='str1') over w1 as w1_count_where_i64_col_col1," - " count_where(i16_col, filter>1) over w1 as w1_count_where_i16_col," - " count_where(i32_col, 1=filter) over w1 as w1_count_where_t_col," - " count_where(s_col, 2filter) over w1 as w1_count_where_date_col," - " count_where(col3, 0>=filter) over w2 as w2_count_where_col3" - " from " + base_table + - " WINDOW w1 AS (PARTITION BY col1,col2 ORDER BY col3" - " ROWS_RANGE BETWEEN 5 PRECEDING AND CURRENT ROW), " - " w2 AS (PARTITION BY col1,col2 ORDER BY i64_col" - " ROWS BETWEEN 6 PRECEDING AND CURRENT ROW);"; + std::string deploy_sql = + R"(DEPLOY test_aggr options(long_windows='w1:2') + SELECT + col1, col2, + count_where(i64_col, filter<1) over w1 as w1_count_where_i64_col_filter, + count_where(i64_col, col1='str1') over w1 as w1_count_where_i64_col_col1, + count_where(i16_col, filter>1) over w1 as w1_count_where_i16_col, + count_where(i32_col, 1=filter) over w1 as w1_count_where_t_col, + count_where(s_col, 2filter) over w1 as w1_count_where_date_col, + count_where(col3, 0>=filter) over w2 as w2_count_where_col3 from )" + + base_table + + R"( + WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 5 PRECEDING AND CURRENT ROW), + w2 AS (PARTITION BY col1,col2 ORDER BY i64_col ROWS BETWEEN 6 PRECEDING AND CURRENT ROW);)"; + sr->ExecuteSQL(base_db, "use " + base_db + ";", &status); ASSERT_TRUE(status.IsOK()) << status.msg; sr->ExecuteSQL(base_db, deploy_sql, &status); @@ -1647,6 +1776,30 @@ TEST_P(DBSDKTest, DeployLongWindowsExecuteCountWhere) { rs = sr->ExecuteSQL(pre_aggr_db, result_sql, &status); ASSERT_EQ(4, rs->Size()); + // 11, 
11, 10, 9, 8, 7, 6 + for (int i = 0; i < 2; i++) { + std::shared_ptr req; + PrepareRequestRowForLongWindow(base_db, "test_aggr", req); + DLOG(INFO) << "Before CallProcedure"; + auto res = sr->CallProcedure(base_db, "test_aggr", req, &status); + DLOG(INFO) << "After CallProcedure"; + EXPECT_TRUE(status.IsOK()); + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_EQ(3, res->GetInt64Unsafe(2)); + EXPECT_EQ(7, res->GetInt64Unsafe(3)); + EXPECT_EQ(0, res->GetInt64Unsafe(4)); + EXPECT_EQ(0, res->GetInt64Unsafe(5)); + EXPECT_EQ(3, res->GetInt64Unsafe(6)); + EXPECT_EQ(3, res->GetInt64Unsafe(7)); + EXPECT_EQ(6, res->GetInt64Unsafe(8)); + EXPECT_EQ(0, res->GetInt64Unsafe(9)); + EXPECT_EQ(6, res->GetInt64Unsafe(10)); + EXPECT_EQ(3, res->GetInt64Unsafe(11)); + } + ASSERT_TRUE(cs->GetNsClient()->DropProcedure(base_db, "test_aggr", msg)); pre_aggr_table = "pre_" + base_db + "_test_aggr_w1_count_where_i64_col_filter"; ok = sr->ExecuteDDL(pre_aggr_db, "drop table " + pre_aggr_table + ";", &status); @@ -1681,6 +1834,656 @@ TEST_P(DBSDKTest, DeployLongWindowsExecuteCountWhere) { ASSERT_TRUE(ok); } +// pre agg rows is range buckets +TEST_P(DBSDKTest, DeployLongWindowsExecuteCountWhere2) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowCountWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowCountWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowCountWhereEnv() override {} + + void Deploy() override { + ProcessSQLs(sr_, {absl::Substitute(R"(DEPLOY $0 options(long_windows='w1:2s') + SELECT + col1, col2, + count_where(i64_col, i64_col<8) over w1 as cw_w1_2, + count_where(i64_col, i16_col > 8) over w1 as cw_w1_3, + count_where(i16_col, i32_col = 10) over w1 as cw_w1_4, + count_where(i32_col, f_col != 10) over w1 as cw_w1_5, + count_where(f_col, d_col <= 10) over w1 as cw_w1_6, + 
count_where(d_col, d_col >= 10) over w1 as cw_w1_7, + count_where(s_col, null = col1) over w1 as cw_w1_8, + count_where(s_col, 'str0' != col1) over w1 as cw_w1_9, + count_where(date_col, null != s_col) over w1 as cw_w1_10, + count_where(*, i64_col > 0) over w1 as cw_w1_11, + count_where(filter, i64_col > 0) over w1 as cw_w1_12, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 6s PRECEDING AND CURRENT ROW);)", + dp_, table_)}); + } + + void TearDownPreAggTables() override { + absl::string_view pre_agg_db = openmldb::nameserver::PRE_AGG_DB; + ProcessSQLs(sr_, { + absl::StrCat("use ", pre_agg_db), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i64_col_i64_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i64_col_i16_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i16_col_i32_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i32_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_f_col_d_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_d_col_d_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_s_col_col1"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_date_col_s_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where__i64_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_filter_i64_col"), + absl::StrCat("use ", db_), + absl::StrCat("drop deployment ", dp_), + }); + } + }; + + // request window [5s, 11s] + DeployLongWindowCountWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + + std::shared_ptr res; + env.CallDeploy(&res); + ASSERT_TRUE(res != nullptr) << "call deploy failed"; + + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_EQ(3, res->GetInt64Unsafe(2)); + EXPECT_EQ(4, 
res->GetInt64Unsafe(3)); + EXPECT_EQ(1, res->GetInt64Unsafe(4)); + EXPECT_EQ(7, res->GetInt64Unsafe(5)); + EXPECT_EQ(6, res->GetInt64Unsafe(6)); + EXPECT_EQ(3, res->GetInt64Unsafe(7)); + EXPECT_EQ(0, res->GetInt64Unsafe(8)); + EXPECT_EQ(8, res->GetInt64Unsafe(9)); + EXPECT_EQ(0, res->GetInt64Unsafe(10)); + EXPECT_EQ(8, res->GetInt64Unsafe(11)); + EXPECT_EQ(7, res->GetInt64Unsafe(12)); +} + +// pre agg rows is range buckets +TEST_P(DBSDKTest, DeployLongWindowsExecuteCountWhere3) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowCountWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowCountWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowCountWhereEnv() override {} + + void Deploy() override { + ProcessSQLs(sr_, {absl::Substitute(R"(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + count_where(i64_col, filter<1) over w1 as w1_count_where_i64_col_filter, + count_where(i64_col, col1='str1') over w1 as w1_count_where_i64_col_col1, + count_where(i16_col, filter>1) over w1 as w1_count_where_i16_col, + count_where(i32_col, 1=filter) over w1 as w1_count_where_t_col, + count_where(s_col, 2filter) over w1 as w1_count_where_date_col, + count_where(col3, 0>=filter) over w2 as w2_count_where_col3 + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW), + w2 AS (PARTITION BY col1,col2 ORDER BY i64_col ROWS BETWEEN 6 PRECEDING AND CURRENT ROW);)", + dp_, table_)}); + } + + void TearDownPreAggTables() override { + absl::string_view pre_agg_db = openmldb::nameserver::PRE_AGG_DB; + ProcessSQLs(sr_, { + absl::StrCat("use ", pre_agg_db), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i64_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i64_col_col1"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_i16_col_filter"), + absl::StrCat("drop table pre_", db_, 
"_", dp_, "_w1_count_where_i32_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_f_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_d_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_t_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_s_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_count_where_date_col_filter"), + absl::StrCat("use ", db_), + absl::StrCat("drop deployment ", dp_), + }); + } + }; + + // request window [4s, 11s] + DeployLongWindowCountWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + + std::shared_ptr res; + // ts 11, 11, 10, 9, 8, 7, 6, 5, 4 + env.CallDeploy(&res); + ASSERT_TRUE(res != nullptr) << "call deploy failed"; + + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_EQ(4, res->GetInt64Unsafe(2)); + EXPECT_EQ(9, res->GetInt64Unsafe(3)); + EXPECT_EQ(0, res->GetInt64Unsafe(4)); + EXPECT_EQ(0, res->GetInt64Unsafe(5)); + EXPECT_EQ(4, res->GetInt64Unsafe(6)); + EXPECT_EQ(4, res->GetInt64Unsafe(7)); + EXPECT_EQ(8, res->GetInt64Unsafe(8)); + EXPECT_EQ(0, res->GetInt64Unsafe(9)); + EXPECT_EQ(8, res->GetInt64Unsafe(10)); + EXPECT_EQ(3, res->GetInt64Unsafe(11)); +} + +TEST_P(DBSDKTest, LongWindowMinMaxWhere) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowMinMaxWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowMinMaxWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowMinMaxWhereEnv() override {} + + void Deploy() override { + ProcessSQLs(sr_, {absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + max_where(i64_col, filter<1) over w1 as m1, + max_where(i64_col, col1='str1') over w1 as m2, + max_where(i16_col, filter>1) over w1 as m3, + 
max_where(i32_col, 1 8) over w1 as m7, + min_where(i16_col, i32_col = 10) over w1 as m8, + min_where(i32_col, f_col != 10) over w1 as m9, + min_where(f_col, d_col <= 10) over w1 as m10, + min_where(d_col, d_col >= 10) over w1 as m11, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_)}); + } + + void TearDownPreAggTables() override { + absl::string_view pre_agg_db = openmldb::nameserver::PRE_AGG_DB; + ProcessSQLs(sr_, + { + absl::StrCat("use ", pre_agg_db), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_max_where_i64_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_max_where_i64_col_col1"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_max_where_i16_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_max_where_i32_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_max_where_f_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_max_where_d_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_min_where_i64_col_i16_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_min_where_i16_col_i32_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_min_where_i32_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_min_where_f_col_d_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_min_where_d_col_d_col"), + absl::StrCat("use ", db_), + absl::StrCat("drop deployment ", dp_), + }); + } + }; + + // request window [4s, 11s] + DeployLongWindowMinMaxWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + + std::shared_ptr res; + // ts 11, 11, 10, 9, 8, 7, 6, 5, 4 + env.CallDeploy(&res); + ASSERT_TRUE(res != nullptr) << "call deploy failed"; + + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_EQ(10, res->GetInt64Unsafe(2)); + 
EXPECT_EQ(11, res->GetInt64Unsafe(3)); + EXPECT_TRUE(res->IsNULL(4)); + EXPECT_TRUE(res->IsNULL(5)); + EXPECT_EQ(10.0, res->GetFloatUnsafe(6)); + EXPECT_EQ(11.0, res->GetDoubleUnsafe(7)); + EXPECT_EQ(9, res->GetInt64Unsafe(8)); + EXPECT_EQ(10, res->GetInt16Unsafe(9)); + EXPECT_EQ(4, res->GetInt32Unsafe(10)); + EXPECT_EQ(4.0, res->GetFloatUnsafe(11)); + EXPECT_EQ(10.0, res->GetDoubleUnsafe(12)); +} + +TEST_P(DBSDKTest, LongWindowSumWhere) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowSumWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowSumWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowSumWhereEnv() override {} + + void Deploy() override { + ProcessSQLs(sr_, {absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:4s') + SELECT + col1, col2, + sum_where(i64_col, col1='str1') over w1 as m1, + sum_where(i16_col, filter>1) over w1 as m2, + sum_where(i32_col, filter = null) over w1 as m3, + sum_where(f_col, 0=filter) over w1 as m4, + sum_where(d_col, 1=filter) over w1 as m5, + sum_where(i64_col, i16_col > 8) over w1 as m6, + sum_where(i16_col, i32_col = 10) over w1 as m7, + sum_where(i32_col, f_col != 10) over w1 as m8, + sum_where(f_col, d_col <= 10) over w1 as m9, + sum_where(d_col, d_col >= 10) over w1 as m10, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_)}); + } + + void TearDownPreAggTables() override { + absl::string_view pre_agg_db = openmldb::nameserver::PRE_AGG_DB; + ProcessSQLs(sr_, { + absl::StrCat("use ", pre_agg_db), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_i64_col_col1"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_i16_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_i32_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_f_col_filter"), + absl::StrCat("drop table pre_", db_, 
"_", dp_, "_w1_sum_where_d_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_i64_col_i16_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_i16_col_i32_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_i32_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_f_col_d_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_sum_where_d_col_d_col"), + absl::StrCat("use ", db_), + absl::StrCat("drop deployment ", dp_), + }); + } + }; + + // request window [4s, 11s] + DeployLongWindowSumWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + + std::shared_ptr res; + // ts 11, 11, 10, 9, 8, 7, 6, 5, 4 + env.CallDeploy(&res); + ASSERT_TRUE(res != nullptr) << "call deploy failed"; + + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_EQ(71, res->GetInt64Unsafe(2)); + EXPECT_TRUE(res->IsNULL(3)); + EXPECT_TRUE(res->IsNULL(4)); + EXPECT_EQ(28.0, res->GetFloatUnsafe(5)); + EXPECT_EQ(32.0, res->GetDoubleUnsafe(6)); + EXPECT_EQ(41, res->GetInt64Unsafe(7)); + EXPECT_EQ(10, res->GetInt16Unsafe(8)); + EXPECT_EQ(61, res->GetInt32Unsafe(9)); + EXPECT_EQ(49.0, res->GetFloatUnsafe(10)); + EXPECT_EQ(32.0, res->GetDoubleUnsafe(11)); +} + +TEST_P(DBSDKTest, LongWindowAvgWhere) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowAvgWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowAvgWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowAvgWhereEnv() override {} + + void Deploy() override { + ProcessSQLs(sr_, {absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + avg_where(i64_col, col1!='str1') over w1 as m1, + avg_where(i16_col, filter<1) over w1 as m2, + avg_where(i32_col, filter = null) over w1 as m3, + avg_where(f_col, 0=filter) 
over w1 as m4, + avg_where(d_col, f_col = 11) over w1 as m5, + avg_where(i64_col, i16_col > 10) over w1 as m6, + avg_where(i16_col, i32_col = 10) over w1 as m7, + avg_where(i32_col, f_col != 7) over w1 as m8, + avg_where(f_col, d_col <= 10) over w1 as m9, + avg_where(d_col, d_col < 4.5) over w1 as m10, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_)}); + } + + void TearDownPreAggTables() override { + absl::string_view pre_agg_db = openmldb::nameserver::PRE_AGG_DB; + ProcessSQLs(sr_, { + absl::StrCat("use ", pre_agg_db), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i64_col_col1"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i16_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i32_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_f_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_d_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i64_col_i16_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i16_col_i32_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i32_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_f_col_d_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_d_col_d_col"), + absl::StrCat("use ", db_), + absl::StrCat("drop deployment ", dp_), + }); + } + }; + + // request window [4s, 11s] + DeployLongWindowAvgWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + + std::shared_ptr res; + // ts 11, 11, 10, 9, 8, 7, 6, 5, 4 + env.CallDeploy(&res); + ASSERT_TRUE(res != nullptr) << "call deploy failed"; + + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_TRUE(res->IsNULL(2)); + EXPECT_EQ(7.0, 
res->GetDoubleUnsafe(3)); + EXPECT_TRUE(res->IsNULL(4)); + EXPECT_EQ(7.0, res->GetDoubleUnsafe(5)); + EXPECT_EQ(11.0, res->GetDoubleUnsafe(6)); + EXPECT_EQ(11.0, res->GetDoubleUnsafe(7)); + EXPECT_EQ(10.0, res->GetDoubleUnsafe(8)); + EXPECT_EQ(8.0, res->GetDoubleUnsafe(9)); + EXPECT_EQ(7.0, res->GetDoubleUnsafe(10)); + EXPECT_EQ(4.0, res->GetDoubleUnsafe(11)); +} + +TEST_P(DBSDKTest, LongWindowAnyWhereWithDataOutOfOrder) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowAnyWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowAnyWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowAnyWhereEnv() override {} + + void Deploy() override { + ProcessSQLs(sr_, {absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + avg_where(i64_col, col1!='str1') over w1 as m1, + avg_where(i16_col, filter<1) over w1 as m2, + avg_where(i32_col, filter = null) over w1 as m3, + avg_where(f_col, 0=filter) over w1 as m4, + avg_where(d_col, f_col = 11) over w1 as m5, + avg_where(i64_col, i16_col > 10) over w1 as m6, + avg_where(i16_col, i32_col = 10) over w1 as m7, + avg_where(i32_col, f_col != 7) over w1 as m8, + avg_where(f_col, d_col <= 10) over w1 as m9, + avg_where(d_col, d_col < 4.5) over w1 as m10, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_)}); + } + + void PrepareData() override { + std::vector order = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + absl::BitGen gen; + absl::c_shuffle(order, gen); + + for (auto i : order) { + std::string val = std::to_string(i); + std::string filter_val = std::to_string(i % 2); + std::string date; + if (i < 10) { + date = absl::StrCat("1900-01-0", std::to_string(i)); + } else { + date = absl::StrCat("1900-01-", std::to_string(i)); + } + std::string insert = absl::StrCat("insert into ", table_, " values('str1', 'str2', ", i * 1000, ", ", + val, ", ", val, ", 
", val, ", ", val, ", ", val, ", ", val, ", '", + val, "', '", date, "', ", filter_val, ");"); + ::hybridse::sdk::Status s; + bool ok = sr_->ExecuteInsert(db_, insert, &s); + ASSERT_TRUE(ok && s.IsOK()) << s.msg << "\n" << s.trace; + } + } + + void TearDownPreAggTables() override { + absl::string_view pre_agg_db = openmldb::nameserver::PRE_AGG_DB; + ProcessSQLs(sr_, { + absl::StrCat("use ", pre_agg_db), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i64_col_col1"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i16_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i32_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_f_col_filter"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_d_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i64_col_i16_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i16_col_i32_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_i32_col_f_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_f_col_d_col"), + absl::StrCat("drop table pre_", db_, "_", dp_, "_w1_avg_where_d_col_d_col"), + absl::StrCat("use ", db_), + absl::StrCat("drop deployment ", dp_), + }); + } + }; + + // request window [4s, 11s] + DeployLongWindowAnyWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + + std::shared_ptr res; + // ts 11, 11, 10, 9, 8, 7, 6, 5, 4 + env.CallDeploy(&res); + ASSERT_TRUE(res != nullptr) << "call deploy failed"; + + EXPECT_EQ(1, res->Size()); + EXPECT_TRUE(res->Next()); + EXPECT_EQ("str1", res->GetStringUnsafe(0)); + EXPECT_EQ("str2", res->GetStringUnsafe(1)); + EXPECT_TRUE(res->IsNULL(2)); + EXPECT_EQ(7.0, res->GetDoubleUnsafe(3)); + EXPECT_TRUE(res->IsNULL(4)); + EXPECT_EQ(7.0, res->GetDoubleUnsafe(5)); + EXPECT_EQ(11.0, res->GetDoubleUnsafe(6)); + EXPECT_EQ(11.0, res->GetDoubleUnsafe(7)); + EXPECT_EQ(10.0, 
res->GetDoubleUnsafe(8)); + EXPECT_EQ(8.0, res->GetDoubleUnsafe(9)); + EXPECT_EQ(7.0, res->GetDoubleUnsafe(10)); + EXPECT_EQ(4.0, res->GetDoubleUnsafe(11)); +} + +TEST_P(DBSDKTest, LongWindowAnyWhereUnsupportRowsBucket) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + class DeployLongWindowAnyWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowAnyWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowAnyWhereEnv() override {} + + void Deploy() override { + hybridse::sdk::Status status; + sr_->ExecuteSQL(absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3') + SELECT + col1, col2, + avg_where(i64_col, col1!='str1') over w1 as m1, + avg_where(i16_col, filter<1) over w1 as m2, + avg_where(i32_col, filter = null) over w1 as m3, + avg_where(f_col, 0=filter) over w1 as m4, + avg_where(d_col, f_col = 11) over w1 as m5, + avg_where(i64_col, i16_col > 10) over w1 as m6, + avg_where(i16_col, i32_col = 10) over w1 as m7, + avg_where(i32_col, f_col != 7) over w1 as m8, + avg_where(f_col, d_col <= 10) over w1 as m9, + avg_where(d_col, d_col < 4.5) over w1 as m10, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_), + &status); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.msg, "unsupport *_where op (avg_where) for rows bucket type long window") + << "code=" << status.code << ", msg=" << status.msg << "\n" + << status.trace; + } + + void TearDownPreAggTables() override {} + }; + + // unsupport: deploy any_where with rows bucket + DeployLongWindowAnyWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; +} + +TEST_P(DBSDKTest, LongWindowAnyWhereUnsupportTimeFilter) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + { + class DeployLongWindowAnyWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowAnyWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) 
{} + ~DeployLongWindowAnyWhereEnv() override {} + + void Deploy() override { + hybridse::sdk::Status status; + sr_->ExecuteSQL(absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + min_where(i64_col, date_col!="2012-12-12") over w1 as m1, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_), + &status); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.msg, "unsupport date or timestamp as filer column (date_col)") + << "code=" << status.code << ", msg=" << status.msg << "\n" + << status.trace; + } + + void TearDownPreAggTables() override {} + }; + + DeployLongWindowAnyWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + } + + { + class DeployLongWindowAnyWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowAnyWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowAnyWhereEnv() override {} + + void Deploy() override { + hybridse::sdk::Status status; + sr_->ExecuteSQL(absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + count_where(i64_col, t_col!="2012-12-12") over w1 as m1, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_), + &status); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.msg, "unsupport date or timestamp as filer column (t_col)") + << "code=" << status.code << ", msg=" << status.msg << "\n" + << status.trace; + } + + void TearDownPreAggTables() override {} + }; + + DeployLongWindowAnyWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; + } +} + +TEST_P(DBSDKTest, LongWindowAnyWhereUnsupportHDDTable) { + // *_where over HDD/SSD table main table not support + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + + if (cs->IsClusterMode()) { + GTEST_SKIP() << "cluster mode skiped because it use same hdd path 
with standalone mode"; + } + + class DeployLongWindowAnyWhereEnv : public DeployLongWindowEnv { + public: + explicit DeployLongWindowAnyWhereEnv(sdk::SQLClusterRouter* sr) : DeployLongWindowEnv(sr) {} + ~DeployLongWindowAnyWhereEnv() override {} + + void PrepareSchema() override { + ProcessSQLs( + sr_, {"SET @@execute_mode='online';", absl::StrCat("create database ", db_), absl::StrCat("use ", db_), + absl::StrCat("create table ", table_, + R"((col1 string, col2 string, col3 timestamp, i64_col bigint, + i16_col smallint, i32_col int, f_col float, + d_col double, t_col timestamp, s_col string, + date_col date, filter int, + index(key=(col1,col2), ts=col3, abs_ttl=0, ttl_type=absolute) + ) options(storage_mode = 'HDD'))")}); + } + + void Deploy() override { + hybridse::sdk::Status status; + sr_->ExecuteSQL(absl::Substitute(R"s(DEPLOY $0 options(long_windows='w1:3s') + SELECT + col1, col2, + avg_where(i64_col, col1!='str1') over w1 as m1, + avg_where(i16_col, filter<1) over w1 as m2, + avg_where(i32_col, filter = null) over w1 as m3, + avg_where(f_col, 0=filter) over w1 as m4, + avg_where(d_col, f_col = 11) over w1 as m5, + avg_where(i64_col, i16_col > 10) over w1 as m6, + avg_where(i16_col, i32_col = 10) over w1 as m7, + avg_where(i32_col, f_col != 7) over w1 as m8, + avg_where(f_col, d_col <= 10) over w1 as m9, + avg_where(d_col, d_col < 4.5) over w1 as m10, + FROM $1 WINDOW + w1 AS (PARTITION BY col1,col2 ORDER BY col3 ROWS_RANGE BETWEEN 7s PRECEDING AND CURRENT ROW))s", + dp_, table_), + &status); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.msg, "avg_where only support over memory base table") + << "code=" << status.code << ", msg=" << status.msg << "\n" + << status.trace; + } + + void TearDownPreAggTables() override {} + }; + + DeployLongWindowAnyWhereEnv env(sr); + env.SetUp(); + absl::Cleanup clean = [&env]() { env.TearDown(); }; +} + TEST_P(DBSDKTest, LongWindowsCleanup) { auto cli = GetParam(); cs = cli->cs; diff --git a/src/sdk/sql_cluster_router.cc 
b/src/sdk/sql_cluster_router.cc index cf8fb68f8b2..e097343c5bc 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -26,6 +26,7 @@ #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/strip.h" +#include "absl/strings/substitute.h" #include "base/ddl_parser.h" #include "base/file_util.h" #include "boost/none.hpp" @@ -3506,6 +3507,31 @@ hybridse::sdk::Status SQLClusterRouter::HandleLongWindows( std::string meta_table = openmldb::nameserver::PRE_AGG_META_NAME; std::string aggr_db = openmldb::nameserver::PRE_AGG_DB; for (const auto& lw : long_window_infos) { + if (absl::EndsWithIgnoreCase(lw.aggr_func_, "_where")) { + // TOOD(ace): *_where op only support for memory base table + if (tables[0].storage_mode() != common::StorageMode::kMemory) { + return {base::ReturnCode::kError, + absl::StrCat(lw.aggr_func_, " only support over memory base table")}; + } + + // TODO(#2313): *_where for rows bucket should support later + if (openmldb::base::IsNumber(long_window_map.at(lw.window_name_))) { + return {base::ReturnCode::kError, absl::StrCat("unsupport *_where op (", lw.aggr_func_, + ") for rows bucket type long window")}; + } + + // unsupport filter col of date/timestamp + for (size_t i = 0; i < tables[0].column_desc_size(); ++i) { + if (lw.filter_col_ == tables[0].column_desc(i).name()) { + auto type = tables[0].column_desc(i).data_type(); + if (type == type::DataType::kDate || type == type::DataType::kTimestamp) { + return { + base::ReturnCode::kError, + absl::Substitute("unsupport date or timestamp as filer column ($0)", lw.filter_col_)}; + } + } + } + } // check if pre-aggr table exists ::hybridse::sdk::Status status; bool is_exist = CheckPreAggrTableExist(base_table, base_db, lw, &status); @@ -3542,10 +3568,10 @@ hybridse::sdk::Status SQLClusterRouter::HandleLongWindows( return {base::ReturnCode::kError, "get tablets failed"}; } auto base_table_info = cluster_sdk_->GetTableInfo(base_db, base_table); - auto 
aggr_id = cluster_sdk_->GetTableId(aggr_db, aggr_table); if (!base_table_info) { return {base::ReturnCode::kError, "get table info failed"}; } + auto aggr_id = cluster_sdk_->GetTableId(aggr_db, aggr_table); ::openmldb::api::TableMeta base_table_meta; base_table_meta.set_db(base_table_info->db()); base_table_meta.set_name(base_table_info->name()); diff --git a/src/storage/aggregator.cc b/src/storage/aggregator.cc index 0bcf67750ce..ea63faa6480 100644 --- a/src/storage/aggregator.cc +++ b/src/storage/aggregator.cc @@ -14,17 +14,19 @@ * limitations under the License. */ +#include "storage/aggregator.h" + #include #include -#include "absl/strings/str_cat.h" -#include "boost/algorithm/string.hpp" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "base/file_util.h" #include "base/glog_wapper.h" #include "base/slice.h" #include "base/strings.h" +#include "boost/algorithm/string.hpp" #include "common/timer.h" -#include "storage/aggregator.h" #include "storage/table.h" DECLARE_bool(binlog_notify_on_put); @@ -124,17 +126,23 @@ bool Aggregator::Update(const std::string& key, const std::string& row, const ui } } + if (!filter_key.empty() && window_type_ != WindowType::kRowsRange) { + LOG(ERROR) << "unsupport rows bucket window for *_where agg op"; + return false; + } + AggrBufferLocked* aggr_buffer_lock; { std::lock_guard lock(mu_); auto it = aggr_buffer_map_.find(key); if (it == aggr_buffer_map_.end()) { - auto insert_pair = aggr_buffer_map_[key].insert(std::make_pair(filter_key, AggrBufferLocked{})); + auto insert_pair = aggr_buffer_map_[key].emplace(filter_key, AggrBufferLocked{}); aggr_buffer_lock = &insert_pair.first->second; } else { - auto filter_it = it->second.find(filter_key); - if (filter_it == it->second.end()) { - auto insert_pair = it->second.emplace(filter_key, AggrBufferLocked{}); + auto& filter_map = it->second; + auto filter_it = filter_map.find(filter_key); + if (filter_it == filter_map.end()) { + auto insert_pair = 
filter_map.emplace(filter_key, AggrBufferLocked{}); aggr_buffer_lock = &insert_pair.first->second; } else { aggr_buffer_lock = &filter_it->second; @@ -360,6 +368,18 @@ bool Aggregator::GetAggrBuffer(const std::string& key, const std::string& filter return true; } +bool Aggregator::SetFilter(absl::string_view filter_col) { + for (int i = 0; i < base_table_schema_.size(); i++) { + if (base_table_schema_.Get(i).name() == filter_col) { + filter_col_ = filter_col; + filter_col_idx_ = i; + return true; + } + } + + return false; +} + bool Aggregator::GetAggrBufferFromRowView(const codec::RowView& row_view, const int8_t* row_ptr, AggrBuffer* buffer) { if (buffer == nullptr) { return false; @@ -926,24 +946,6 @@ bool CountAggregator::UpdateAggrVal(const codec::RowView& row_view, const int8_t return true; } -CountWhereAggregator::CountWhereAggregator(const ::openmldb::api::TableMeta& base_meta, - const ::openmldb::api::TableMeta& aggr_meta, - std::shared_ptr aggr_table, - std::shared_ptr aggr_replicator, const uint32_t& index_pos, - const std::string& aggr_col, const AggrType& aggr_type, - const std::string& ts_col, WindowType window_tpye, uint32_t window_size, - const std::string& filter_col) - : CountAggregator(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, aggr_type, ts_col, - window_tpye, window_size) { - filter_col_ = filter_col; - for (int i = 0; i < base_meta.column_desc().size(); i++) { - if (base_meta.column_desc(i).name() == filter_col_) { - filter_col_idx_ = i; - break; - } - } -} - AvgAggregator::AvgAggregator(const ::openmldb::api::TableMeta& base_meta, const ::openmldb::api::TableMeta& aggr_meta, std::shared_ptr
aggr_table, std::shared_ptr aggr_replicator, const uint32_t& index_pos, const std::string& aggr_col, const AggrType& aggr_type, @@ -1031,14 +1033,14 @@ std::shared_ptr CreateAggregator(const ::openmldb::api::TableMeta& b window_type = WindowType::kRowsRange; if (bucket_size.empty()) { PDLOG(ERROR, "Bucket size is empty"); - return std::shared_ptr(); + return {}; } char time_unit = tolower(bucket_size.back()); std::string time_size = bucket_size.substr(0, bucket_size.size() - 1); boost::trim(time_size); if (!::openmldb::base::IsNumber(time_size)) { PDLOG(ERROR, "Bucket size is not a number"); - return std::shared_ptr(); + return {}; } switch (time_unit) { case 's': @@ -1055,39 +1057,47 @@ std::shared_ptr CreateAggregator(const ::openmldb::api::TableMeta& b break; default: { PDLOG(ERROR, "Unsupported time unit"); - return std::shared_ptr(); + return {}; } } } - if (aggr_type == "sum") { - return std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, + std::shared_ptr agg; + if (aggr_type == "sum" || aggr_type == "sum_where") { + agg = std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, AggrType::kSum, ts_col, window_type, window_size); - } else if (aggr_type == "min") { - return std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, + } else if (aggr_type == "min" || aggr_type == "min_where") { + agg = std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, AggrType::kMin, ts_col, window_type, window_size); - } else if (aggr_type == "max") { - return std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, - AggrType::kMax, ts_col, window_type, window_size); - } else if (aggr_type == "count") { - return std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, - AggrType::kCount, ts_col, window_type, window_size); - } else if (aggr_type == "avg") { - return 
std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, - AggrType::kAvg, ts_col, window_type, window_size); - } else if (aggr_type == "count_where") { - if (filter_col.empty()) { - PDLOG(ERROR, "no filter column specified for count_where"); - return std::shared_ptr(); - } - return std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, - aggr_col, AggrType::kCountWhere, ts_col, window_type, window_size, - filter_col); + } else if (aggr_type == "max" || aggr_type == "max_where") { + agg = std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, + AggrType::kMax, ts_col, window_type, window_size); + } else if (aggr_type == "count" || aggr_type == "count_where") { + agg = std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, + AggrType::kCount, ts_col, window_type, window_size); + } else if (aggr_type == "avg" || aggr_type == "avg_where") { + agg = std::make_shared(base_meta, aggr_meta, aggr_table, aggr_replicator, index_pos, aggr_col, + AggrType::kAvg, ts_col, window_type, window_size); } else { PDLOG(ERROR, "Unsupported aggregate function type"); - return std::shared_ptr(); + return {}; + } + + if (filter_col.empty() || !absl::EndsWithIgnoreCase(aggr_type, "_where")) { + // min/max/count/avg/sum ops + return agg; + } + + // _where variant + if (filter_col.empty()) { + PDLOG(ERROR, "no filter column specified for %s", aggr_type); + return {}; + } + if (!agg->SetFilter(filter_col)) { + PDLOG(ERROR, "can not find filter column '%s' for %s", filter_col, aggr_type); + return {}; } - return std::shared_ptr(); + return agg; } } // namespace storage diff --git a/src/storage/aggregator.h b/src/storage/aggregator.h index 36fbf407f80..1552e936df3 100644 --- a/src/storage/aggregator.h +++ b/src/storage/aggregator.h @@ -43,7 +43,6 @@ enum class AggrType { kMax = 3, kCount = 4, kAvg = 5, - kCountWhere = 6, }; enum class WindowType { @@ -152,6 +151,9 @@ class 
Aggregator { uint32_t GetAggrTid() { return aggr_table_->GetId(); } + // set the filter column info that not initialized in constructor + bool SetFilter(absl::string_view filter_col); + protected: codec::Schema base_table_schema_; codec::Schema aggr_table_schema_; @@ -282,17 +284,6 @@ class CountAggregator : public Aggregator { bool count_all = false; }; -class CountWhereAggregator : public CountAggregator { - public: - CountWhereAggregator(const ::openmldb::api::TableMeta& base_meta, const ::openmldb::api::TableMeta& aggr_meta, - std::shared_ptr
aggr_table, std::shared_ptr aggr_replicator, - const uint32_t& index_pos, const std::string& aggr_col, const AggrType& aggr_type, - const std::string& ts_col, WindowType window_tpye, uint32_t window_size, - const std::string& filter_col); - - ~CountWhereAggregator() = default; -}; - class AvgAggregator : public Aggregator { public: AvgAggregator(const ::openmldb::api::TableMeta& base_meta, const ::openmldb::api::TableMeta& aggr_meta, diff --git a/src/tablet/tablet_impl.cc b/src/tablet/tablet_impl.cc index afae8599370..c9b4972fe45 100644 --- a/src/tablet/tablet_impl.cc +++ b/src/tablet/tablet_impl.cc @@ -4387,6 +4387,8 @@ bool TabletImpl::RefreshAggrCatalog() { table_info.order_by_col.assign(str, len); row_view.GetValue(row.buf(), 8, &str, &len); table_info.bucket_size.assign(str, len); + row_view.GetValue(row.buf(), 9, &str, &len); + table_info.filter_col.assign(str, len); table_infos.emplace_back(std::move(table_info)); it->Next();