diff --git a/dbms/src/Flash/Coprocessor/DAGContext.cpp b/dbms/src/Flash/Coprocessor/DAGContext.cpp index 1ef7338a589..ec0544c6ee4 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.cpp +++ b/dbms/src/Flash/Coprocessor/DAGContext.cpp @@ -206,12 +206,20 @@ void DAGContext::attachBlockIO(const BlockIO & io_) io = io_; } -const std::unordered_map> & DAGContext::getMPPExchangeReceiverMap() const +ExchangeReceiverPtr DAGContext::getMPPExchangeReceiver(const String & executor_id) const { if (!isMPPTask()) throw TiFlashException("mpp_exchange_receiver_map is used in mpp only", Errors::Coprocessor::Internal); - RUNTIME_ASSERT(mpp_exchange_receiver_map != nullptr, log, "MPPTask without exchange receiver map"); - return *mpp_exchange_receiver_map; + RUNTIME_ASSERT(mpp_receiver_set != nullptr, log, "MPPTask without receiver set"); + return mpp_receiver_set->getExchangeReceiver(executor_id); +} + +void DAGContext::addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader) +{ + if (!isMPPTask()) + return; + RUNTIME_ASSERT(mpp_receiver_set != nullptr, log, "MPPTask without receiver set"); + return mpp_receiver_set->addCoprocessorReader(coprocessor_reader); } bool DAGContext::containsRegionsInfoForTable(Int64 table_id) const diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index 07b65b2d8fe..8b94d4637a8 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -37,8 +37,13 @@ namespace DB class Context; class MPPTunnelSet; class ExchangeReceiver; -using ExchangeReceiverMap = std::unordered_map>; -using ExchangeReceiverMapPtr = std::shared_ptr>>; +using ExchangeReceiverPtr = std::shared_ptr; +/// key: executor_id of ExchangeReceiver nodes in dag. +using ExchangeReceiverMap = std::unordered_map; +class MPPReceiverSet; +using MPPReceiverSetPtr = std::shared_ptr; +class CoprocessorReader; +using CoprocessorReaderPtr = std::shared_ptr; class Join; using JoinPtr = std::shared_ptr; @@ -304,11 +309,12 @@ class DAGContext bool columnsForTestEmpty() { return columns_for_test_map.empty(); } - const std::unordered_map> & getMPPExchangeReceiverMap() const; - void setMPPExchangeReceiverMap(ExchangeReceiverMapPtr & exchange_receiver_map) + ExchangeReceiverPtr getMPPExchangeReceiver(const String & executor_id) const; + void setMPPReceiverSet(const MPPReceiverSetPtr & receiver_set) { - mpp_exchange_receiver_map = exchange_receiver_map; + mpp_receiver_set = receiver_set; } + void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader); void addSubquery(const String & subquery_id, SubqueryForSet && subquery); bool hasSubquery() const { return !subqueries.empty(); } @@ -369,8 +375,8 @@ class DAGContext ConcurrentBoundedQueue warnings; /// warning_count is the actual warning count during the entire execution std::atomic warning_count; - /// key: executor_id of ExchangeReceiver nodes in dag. - ExchangeReceiverMapPtr mpp_exchange_receiver_map; + + MPPReceiverSetPtr mpp_receiver_set; /// vector of SubqueriesForSets(such as join build subquery). /// The order of the vector is also the order of the subquery. std::vector subqueries; diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 86d6428c92a..e322a830744 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -481,14 +481,14 @@ void DAGQueryBlockInterpreter::recordProfileStreams(DAGPipeline & pipeline, cons void DAGQueryBlockInterpreter::handleExchangeReceiver(DAGPipeline & pipeline) { - auto it = dagContext().getMPPExchangeReceiverMap().find(query_block.source_name); - if (unlikely(it == dagContext().getMPPExchangeReceiverMap().end())) + auto exchange_receiver = dagContext().getMPPExchangeReceiver(query_block.source_name); + if (unlikely(exchange_receiver == nullptr)) throw Exception("Can not find exchange receiver for " + query_block.source_name, ErrorCodes::LOGICAL_ERROR); // todo choose a more reasonable stream number auto & exchange_receiver_io_input_streams = dagContext().getInBoundIOInputStreamsMap()[query_block.source_name]; for (size_t i = 0; i < max_streams; ++i) { - BlockInputStreamPtr stream = std::make_shared(it->second, log->identifier(), query_block.source_name); + BlockInputStreamPtr stream = std::make_shared(exchange_receiver, log->identifier(), query_block.source_name); exchange_receiver_io_input_streams.push_back(stream); stream = std::make_shared(stream, 8192, 0, log->identifier()); stream->setExtraInfo("squashing after exchange receiver"); diff --git a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp index 14cddd94730..ad2de7217e0 100644 --- a/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp @@ -486,6 +486,7 @@ void DAGStorageInterpreter::buildRemoteStreams(std::vector && rem std::vector tasks(all_tasks.begin() + task_start, all_tasks.begin() + task_end); auto coprocessor_reader = std::make_shared(schema, cluster, tasks, has_enforce_encode_type, 1); + context.getDAGContext()->addCoprocessorReader(coprocessor_reader); BlockInputStreamPtr input = std::make_shared(coprocessor_reader, log->identifier(), table_scan.getTableScanExecutorID()); pipeline.streams.push_back(input); task_start = task_end; diff --git a/dbms/src/Flash/Mpp/MPPReceiverSet.cpp b/dbms/src/Flash/Mpp/MPPReceiverSet.cpp new file mode 100644 index 00000000000..60cca308c18 --- /dev/null +++ b/dbms/src/Flash/Mpp/MPPReceiverSet.cpp @@ -0,0 +1,48 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ +void MPPReceiverSet::addExchangeReceiver(const String & executor_id, const ExchangeReceiverPtr & exchange_receiver) +{ + RUNTIME_ASSERT(exchange_receiver_map.find(executor_id) == exchange_receiver_map.end(), log, "Duplicate executor_id: {} in DAGRequest", executor_id); + exchange_receiver_map[executor_id] = exchange_receiver; +} + +void MPPReceiverSet::addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader) +{ + coprocessor_readers.push_back(coprocessor_reader); +} + +ExchangeReceiverPtr MPPReceiverSet::getExchangeReceiver(const String & executor_id) const +{ + auto it = exchange_receiver_map.find(executor_id); + if (unlikely(it == exchange_receiver_map.end())) + return nullptr; + return it->second; +} + +void MPPReceiverSet::cancel() +{ + for (auto & it : exchange_receiver_map) + { + it.second->cancel(); + } + for (auto & cop_reader : coprocessor_readers) + cop_reader->cancel(); +} +} // namespace DB diff --git a/dbms/src/Flash/Mpp/MPPReceiverSet.h b/dbms/src/Flash/Mpp/MPPReceiverSet.h new file mode 100644 index 00000000000..44274cb3ce8 --- /dev/null +++ b/dbms/src/Flash/Mpp/MPPReceiverSet.h @@ -0,0 +1,44 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB +{ +class MPPReceiverSet +{ +public: + explicit MPPReceiverSet(const String & req_id) + : log(Logger::get("MPPReceiverSet", req_id)) + {} + void addExchangeReceiver(const String & executor_id, const ExchangeReceiverPtr & exchange_receiver); + void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader); + ExchangeReceiverPtr getExchangeReceiver(const String & executor_id) const; + void cancel(); + +private: + /// two kinds of receiver in MPP + /// ExchangeReceiver: receiver data from other MPPTask + /// CoprocessorReader: used in remote read + ExchangeReceiverMap exchange_receiver_map; + std::vector coprocessor_readers; + const LoggerPtr log; +}; + +using MPPReceiverSetPtr = std::shared_ptr; + +} // namespace DB diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 40f03ff79ba..0381bbdfa04 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -125,7 +125,7 @@ void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) void MPPTask::initExchangeReceivers() { - mpp_exchange_receiver_map = std::make_shared(); + receiver_set = std::make_shared(log->identifier()); traverseExecutors(&dag_req, [&](const tipb::Executor & executor) { if (executor.tp() == tipb::ExecType::TypeExchangeReceiver) { @@ -147,22 +147,19 @@ void MPPTask::initExchangeReceivers() if (status != RUNNING) throw Exception("exchange receiver map can not be initialized, because the task is not in running state"); - (*mpp_exchange_receiver_map)[executor_id] = exchange_receiver; + receiver_set->addExchangeReceiver(executor_id, exchange_receiver); new_thread_count_of_exchange_receiver += exchange_receiver->computeNewThreadCount(); } return true; }); - dag_context->setMPPExchangeReceiverMap(mpp_exchange_receiver_map); + dag_context->setMPPReceiverSet(receiver_set); } -void MPPTask::cancelAllExchangeReceivers() +void MPPTask::cancelAllReceivers() { - if (likely(mpp_exchange_receiver_map != nullptr)) + if (likely(receiver_set != nullptr)) { - for (auto & it : *mpp_exchange_receiver_map) - { - it.second->cancel(); - } + receiver_set->cancel(); } } @@ -393,7 +390,7 @@ void MPPTask::runImpl() else { context->getProcessList().sendCancelToQuery(context->getCurrentQueryId(), context->getClientInfo().current_user, true); - cancelAllExchangeReceivers(); + cancelAllReceivers(); writeErrToAllTunnels(err_msg); } LOG_FMT_INFO(log, "task ends, time cost is {} ms.", stopwatch.elapsedMilliseconds()); diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index c8423ac484c..d7e5ed169de 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -109,7 +110,7 @@ class MPPTask : public std::enable_shared_from_this void initExchangeReceivers(); - void cancelAllExchangeReceivers(); + void cancelAllReceivers(); tipb::DAGRequest dag_req; @@ -126,8 +127,8 @@ class MPPTask : public std::enable_shared_from_this MPPTaskId id; MPPTunnelSetPtr tunnel_set; - /// key: executor_id of ExchangeReceiver nodes in dag. - ExchangeReceiverMapPtr mpp_exchange_receiver_map; + + MPPReceiverSetPtr receiver_set; int new_thread_count_of_exchange_receiver = 0;