diff --git a/dbms/src/Flash/Coprocessor/DAGContext.cpp b/dbms/src/Flash/Coprocessor/DAGContext.cpp index 1736e0b6cec..1ef7338a589 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.cpp +++ b/dbms/src/Flash/Coprocessor/DAGContext.cpp @@ -205,60 +205,13 @@ void DAGContext::attachBlockIO(const BlockIO & io_) { io = io_; } -void DAGContext::initExchangeReceiverIfMPP(Context & context, size_t max_streams) -{ - if (isMPPTask()) - { - if (mpp_exchange_receiver_map_inited) - throw TiFlashException("Repeatedly initialize mpp_exchange_receiver_map", Errors::Coprocessor::Internal); - traverseExecutors(dag_request, [&](const tipb::Executor & executor) { - if (executor.tp() == tipb::ExecType::TypeExchangeReceiver) - { - assert(executor.has_executor_id()); - const auto & executor_id = executor.executor_id(); - // In order to distinguish different exchange receivers. - auto exchange_receiver = std::make_shared( - std::make_shared( - executor.exchange_receiver(), - getMPPTaskMeta(), - context.getTMTContext().getKVCluster(), - context.getTMTContext().getMPPTaskManager(), - context.getSettingsRef().enable_local_tunnel, - context.getSettingsRef().enable_async_grpc_client), - executor.exchange_receiver().encoded_task_meta_size(), - max_streams, - log->identifier(), - executor_id); - mpp_exchange_receiver_map[executor_id] = exchange_receiver; - new_thread_count_of_exchange_receiver += exchange_receiver->computeNewThreadCount(); - } - return true; - }); - mpp_exchange_receiver_map_inited = true; - } -} - const std::unordered_map> & DAGContext::getMPPExchangeReceiverMap() const { if (!isMPPTask()) throw TiFlashException("mpp_exchange_receiver_map is used in mpp only", Errors::Coprocessor::Internal); - if (!mpp_exchange_receiver_map_inited) - throw TiFlashException("mpp_exchange_receiver_map has not been initialized", Errors::Coprocessor::Internal); - return mpp_exchange_receiver_map; -} - -void DAGContext::cancelAllExchangeReceiver() -{ - for (auto & it : mpp_exchange_receiver_map) - { - it.second->cancel(); - } -} - -int DAGContext::getNewThreadCountOfExchangeReceiver() const -{ - return new_thread_count_of_exchange_receiver; + RUNTIME_ASSERT(mpp_exchange_receiver_map != nullptr, log, "MPPTask without exchange receiver map"); + return *mpp_exchange_receiver_map; } bool DAGContext::containsRegionsInfoForTable(Int64 table_id) const diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index c20eb3a367e..07b65b2d8fe 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -37,6 +37,8 @@ namespace DB class Context; class MPPTunnelSet; class ExchangeReceiver; +using ExchangeReceiverMap = std::unordered_map>; +using ExchangeReceiverMapPtr = std::shared_ptr>>; class Join; using JoinPtr = std::shared_ptr; @@ -254,7 +256,6 @@ class DAGContext return io; } - int getNewThreadCountOfExchangeReceiver() const; UInt64 getFlags() const { return flags; @@ -303,10 +304,11 @@ class DAGContext bool columnsForTestEmpty() { return columns_for_test_map.empty(); } - void cancelAllExchangeReceiver(); - - void initExchangeReceiverIfMPP(Context & context, size_t max_streams); const std::unordered_map> & getMPPExchangeReceiverMap() const; + void setMPPExchangeReceiverMap(ExchangeReceiverMapPtr & exchange_receiver_map) + { + mpp_exchange_receiver_map = exchange_receiver_map; + } void addSubquery(const String & subquery_id, SubqueryForSet && subquery); bool hasSubquery() const { return !subqueries.empty(); } @@ -367,10 +369,8 @@ class DAGContext ConcurrentBoundedQueue warnings; /// warning_count is the actual warning count during the entire execution std::atomic warning_count; - int new_thread_count_of_exchange_receiver = 0; /// key: executor_id of ExchangeReceiver nodes in dag. - std::unordered_map> mpp_exchange_receiver_map; - bool mpp_exchange_receiver_map_inited = false; + ExchangeReceiverMapPtr mpp_exchange_receiver_map; /// vector of SubqueriesForSets(such as join build subquery). /// The order of the vector is also the order of the subquery. std::vector subqueries; diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index bea26fe9f99..9ffa29cd14d 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -1432,6 +1432,7 @@ tipb::EncodeType analyzeDAGEncodeType(DAGContext & dag_context) return tipb::EncodeType::TypeDefault; return encode_type; } + tipb::ScalarFuncSig reverseGetFuncSigByFuncName(const String & name) { static std::unordered_map func_name_sig_map = getFuncNameToSigMap(); diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp index a67ebf20aa5..0e767d65d77 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp @@ -24,19 +24,8 @@ namespace DB InterpreterDAG::InterpreterDAG(Context & context_, const DAGQuerySource & dag_) : context(context_) , dag(dag_) + , max_streams(context.getMaxStreams()) { - const Settings & settings = context.getSettingsRef(); - if (dagContext().isBatchCop() || (dagContext().isMPPTask() && !dagContext().isTest())) - max_streams = settings.max_threads; - else if (dagContext().isTest()) - max_streams = dagContext().initialize_concurrency; - else - max_streams = 1; - - if (max_streams > 1) - { - max_streams *= settings.max_streams_to_max_threads_ratio; - } } void setRestorePipelineConcurrency(DAGQueryBlock & query_block) @@ -75,10 +64,6 @@ BlockInputStreams InterpreterDAG::executeQueryBlock(DAGQueryBlock & query_block) BlockIO InterpreterDAG::execute() { - /// Due to learner read, DAGQueryBlockInterpreter may take a long time to build - /// the query plan, so we init mpp exchange receiver before executeQueryBlock - dagContext().initExchangeReceiverIfMPP(context, max_streams); - BlockInputStreams streams = executeQueryBlock(*dag.getRootQueryBlock()); DAGPipeline pipeline; pipeline.streams = streams; diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 8f9ca8e55e5..40f03ff79ba 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -22,11 +22,14 @@ #include #include #include +#include +#include #include #include #include #include #include +#include #include #include #include @@ -94,13 +97,73 @@ void MPPTask::run() newThreadManager()->scheduleThenDetach(true, "MPPTask", [self = shared_from_this()] { self->runImpl(); }); } -void MPPTask::registerTunnel(const MPPTaskId & task_id, MPPTunnelPtr tunnel) +void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) { - if (status == CANCELLED) - throw Exception("the tunnel " + tunnel->id() + " can not been registered, because the task is cancelled"); + tunnel_set = std::make_shared(log->identifier()); + std::chrono::seconds timeout(task_request.timeout()); + const auto & exchange_sender = dag_req.root_executor().exchange_sender(); - RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set"); - tunnel_set->registerTunnel(task_id, tunnel); + for (int i = 0; i < exchange_sender.encoded_task_meta_size(); ++i) + { + // exchange sender will register the tunnels and wait receiver to found a connection. + mpp::TaskMeta task_meta; + if (unlikely(!task_meta.ParseFromString(exchange_sender.encoded_task_meta(i)))) + throw TiFlashException("Failed to decode task meta info in ExchangeSender", Errors::Coprocessor::BadRequest); + bool is_local = context->getSettingsRef().enable_local_tunnel && meta.address() == task_meta.address(); + bool is_async = !is_local && context->getSettingsRef().enable_async_server; + MPPTunnelPtr tunnel = std::make_shared(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier()); + LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id()); + if (status != INITIALIZING) + throw Exception(fmt::format("The tunnel {} can not be registered, because the task is not in initializing state", tunnel->id())); + tunnel_set->registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); + if (!dag_context->isRootMPPTask()) + { + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task); + } + } +} + +void MPPTask::initExchangeReceivers() +{ + mpp_exchange_receiver_map = std::make_shared(); + traverseExecutors(&dag_req, [&](const tipb::Executor & executor) { + if (executor.tp() == tipb::ExecType::TypeExchangeReceiver) + { + assert(executor.has_executor_id()); + const auto & executor_id = executor.executor_id(); + // In order to distinguish different exchange receivers. + auto exchange_receiver = std::make_shared( + std::make_shared( + executor.exchange_receiver(), + dag_context->getMPPTaskMeta(), + context->getTMTContext().getKVCluster(), + context->getTMTContext().getMPPTaskManager(), + context->getSettingsRef().enable_local_tunnel, + context->getSettingsRef().enable_async_grpc_client), + executor.exchange_receiver().encoded_task_meta_size(), + context->getMaxStreams(), + log->identifier(), + executor_id); + if (status != RUNNING) + throw Exception("exchange receiver map can not be initialized, because the task is not in running state"); + + (*mpp_exchange_receiver_map)[executor_id] = exchange_receiver; + new_thread_count_of_exchange_receiver += exchange_receiver->computeNewThreadCount(); + } + return true; + }); + dag_context->setMPPExchangeReceiverMap(mpp_exchange_receiver_map); +} + +void MPPTask::cancelAllExchangeReceivers() +{ + if (likely(mpp_exchange_receiver_map != nullptr)) + { + for (auto & it : *mpp_exchange_receiver_map) + { + it.second->cancel(); + } + } } std::pair MPPTask::getTunnel(const ::mpp::EstablishMPPConnectionRequest * request) @@ -116,7 +179,7 @@ std::pair MPPTask::getTunnel(const ::mpp::EstablishMPPConn MPPTaskId receiver_id{request->receiver_meta().start_ts(), request->receiver_meta().task_id()}; RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set"); - auto tunnel_ptr = tunnel_set->getTunnelById(receiver_id); + auto tunnel_ptr = tunnel_set->getTunnelByReceiverTaskId(receiver_id); if (tunnel_ptr == nullptr) { auto err_msg = fmt::format( @@ -207,25 +270,8 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) } // register tunnels - tunnel_set = std::make_shared(log->identifier()); - std::chrono::seconds timeout(task_request.timeout()); + registerTunnels(task_request); - for (int i = 0; i < exchange_sender.encoded_task_meta_size(); i++) - { - // exchange sender will register the tunnels and wait receiver to found a connection. - mpp::TaskMeta task_meta; - if (!task_meta.ParseFromString(exchange_sender.encoded_task_meta(i))) - throw TiFlashException("Failed to decode task meta info in ExchangeSender", Errors::Coprocessor::BadRequest); - bool is_local = context->getSettingsRef().enable_local_tunnel && meta.address() == task_meta.address(); - bool is_async = !is_local && context->getSettingsRef().enable_async_server; - MPPTunnelPtr tunnel = std::make_shared(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier()); - LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id()); - registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); - if (!dag_context->isRootMPPTask()) - { - FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task); - } - } dag_context->tunnel_set = tunnel_set; // register task. auto task_manager = tmt_context.getMPPTaskManager(); @@ -251,6 +297,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) void MPPTask::preprocess() { auto start_time = Clock::now(); + initExchangeReceivers(); DAGQuerySource dag(*context); executeQuery(dag, *context, false, QueryProcessingStage::Complete); auto end_time = Clock::now(); @@ -280,7 +327,7 @@ void MPPTask::runImpl() LOG_FMT_INFO(log, "task starts preprocessing"); preprocess(); needed_threads = estimateCountOfNewThreads(); - LOG_FMT_DEBUG(log, "Estimate new thread count of query :{} including tunnel_threads: {} , receiver_threads: {}", needed_threads, dag_context->tunnel_set->getRemoteTunnelCnt(), dag_context->getNewThreadCountOfExchangeReceiver()); + LOG_FMT_DEBUG(log, "Estimate new thread count of query :{} including tunnel_threads: {} , receiver_threads: {}", needed_threads, dag_context->tunnel_set->getRemoteTunnelCnt(), new_thread_count_of_exchange_receiver); scheduleOrWait(); @@ -346,8 +393,7 @@ void MPPTask::runImpl() else { context->getProcessList().sendCancelToQuery(context->getCurrentQueryId(), context->getClientInfo().current_user, true); - if (dag_context) - dag_context->cancelAllExchangeReceiver(); + cancelAllExchangeReceivers(); writeErrToAllTunnels(err_msg); } LOG_FMT_INFO(log, "task ends, time cost is {} ms.", stopwatch.elapsedMilliseconds()); diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index ee434a2f2ff..c8423ac484c 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -62,8 +62,6 @@ class MPPTask : public std::enable_shared_from_this void run(); - void registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel); - int getNeededThreads(); enum class ScheduleState @@ -107,6 +105,12 @@ class MPPTask : public std::enable_shared_from_this int estimateCountOfNewThreads(); + void registerTunnels(const mpp::DispatchTaskRequest & task_request); + + void initExchangeReceivers(); + + void cancelAllExchangeReceivers(); + tipb::DAGRequest dag_req; ContextPtr context; @@ -122,6 +126,10 @@ class MPPTask : public std::enable_shared_from_this MPPTaskId id; MPPTunnelSetPtr tunnel_set; + /// key: executor_id of ExchangeReceiver nodes in dag. + ExchangeReceiverMapPtr mpp_exchange_receiver_map; + + int new_thread_count_of_exchange_receiver = 0; MPPTaskManager * manager = nullptr; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 500e9501b08..8d709bb7d38 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -133,12 +133,12 @@ void MPPTunnelSetBase::writeError(const String & msg) } template -void MPPTunnelSetBase::registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel) +void MPPTunnelSetBase::registerTunnel(const MPPTaskId & receiver_task_id, const TunnelPtr & tunnel) { - if (id_to_index_map.find(id) != id_to_index_map.end()) - throw Exception("the tunnel " + tunnel->id() + " has been registered"); + if (receiver_task_id_to_index_map.find(receiver_task_id) != receiver_task_id_to_index_map.end()) + throw Exception(fmt::format("the tunnel {} has been registered", tunnel->id())); - id_to_index_map[id] = tunnels.size(); + receiver_task_id_to_index_map[receiver_task_id] = tunnels.size(); tunnels.push_back(tunnel); if (!tunnel->isLocal()) { @@ -163,10 +163,10 @@ void MPPTunnelSetBase::finishWrite() } template -typename MPPTunnelSetBase::TunnelPtr MPPTunnelSetBase::getTunnelById(const MPPTaskId & id) +typename MPPTunnelSetBase::TunnelPtr MPPTunnelSetBase::getTunnelByReceiverTaskId(const MPPTaskId & id) { - auto it = id_to_index_map.find(id); - if (it == id_to_index_map.end()) + auto it = receiver_task_id_to_index_map.find(id); + if (it == receiver_task_id_to_index_map.end()) { return nullptr; } diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index 021c609f516..e4123db1be5 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -59,7 +59,7 @@ class MPPTunnelSetBase : private boost::noncopyable void finishWrite(); void registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel); - TunnelPtr getTunnelById(const MPPTaskId & id); + TunnelPtr getTunnelByReceiverTaskId(const MPPTaskId & id); uint16_t getPartitionNum() const { return tunnels.size(); } @@ -72,7 +72,7 @@ class MPPTunnelSetBase : private boost::noncopyable private: std::vector tunnels; - std::unordered_map id_to_index_map; + std::unordered_map receiver_task_id_to_index_map; const LoggerPtr log; int remote_tunnel_cnt = 0; diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index a0adef5b50d..3beedbd3601 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -1879,6 +1880,30 @@ SharedQueriesPtr Context::getSharedQueries() return shared->shared_queries; } +size_t Context::getMaxStreams() const +{ + size_t max_streams = settings.max_threads; + bool is_cop_request = false; + if (dag_context != nullptr) + { + if (dag_context->isTest()) + max_streams = dag_context->initialize_concurrency; + else if (!dag_context->isBatchCop() && !dag_context->isMPPTask()) + { + is_cop_request = true; + max_streams = 1; + } + } + if (max_streams > 1) + max_streams *= settings.max_streams_to_max_threads_ratio; + if (max_streams == 0) + max_streams = 1; + if (unlikely(max_streams != 1 && is_cop_request)) + /// for cop request, the max_streams should be 1 + throw Exception("Cop request only support running with max_streams = 1"); + return max_streams; +} + SessionCleaner::~SessionCleaner() { try diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 5d5c39263c6..b6e759e364b 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -459,6 +459,8 @@ class Context void reloadDeltaTreeConfig(const Poco::Util::AbstractConfiguration & config); + size_t getMaxStreams() const; + private: /** Check if the current client has access to the specified database. * If access is denied, throw an exception.