diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 0f18ad582b4..8f9ca8e55e5 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -56,6 +56,7 @@ MPPTask::MPPTask(const mpp::TaskMeta & meta_, const ContextPtr & context_) , id(meta.start_ts(), meta.task_id()) , log(Logger::get("MPPTask", id.toString())) , mpp_task_statistics(id, meta.address()) + , needed_threads(0) , schedule_state(ScheduleState::WAITING) {} @@ -78,18 +79,14 @@ MPPTask::~MPPTask() void MPPTask::closeAllTunnels(const String & reason) { - for (auto & it : tunnel_map) - { - it.second->close(reason); - } + if (likely(tunnel_set)) + tunnel_set->close(reason); } void MPPTask::finishWrite() { - for (const auto & it : tunnel_map) - { - it.second->writeDone(); - } + RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set"); + tunnel_set->finishWrite(); } void MPPTask::run() @@ -97,15 +94,13 @@ void MPPTask::run() newThreadManager()->scheduleThenDetach(true, "MPPTask", [self = shared_from_this()] { self->runImpl(); }); } -void MPPTask::registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel) +void MPPTask::registerTunnel(const MPPTaskId & task_id, MPPTunnelPtr tunnel) { if (status == CANCELLED) throw Exception("the tunnel " + tunnel->id() + " can not been registered, because the task is cancelled"); - if (tunnel_map.find(id) != tunnel_map.end()) - throw Exception("the tunnel " + tunnel->id() + " has been registered"); - - tunnel_map[id] = tunnel; + RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set"); + tunnel_set->registerTunnel(task_id, tunnel); } std::pair MPPTask::getTunnel(const ::mpp::EstablishMPPConnectionRequest * request) @@ -120,8 +115,9 @@ std::pair MPPTask::getTunnel(const ::mpp::EstablishMPPConn } MPPTaskId receiver_id{request->receiver_meta().start_ts(), request->receiver_meta().task_id()}; - auto it = tunnel_map.find(receiver_id); - if (it == tunnel_map.end()) + RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set"); + auto tunnel_ptr = tunnel_set->getTunnelById(receiver_id); + if (tunnel_ptr == nullptr) { auto err_msg = fmt::format( "can't find tunnel ({} + {})", @@ -129,7 +125,7 @@ std::pair MPPTask::getTunnel(const ::mpp::EstablishMPPConn request->receiver_meta().task_id()); return {nullptr, err_msg}; } - return {it->second, ""}; + return {tunnel_ptr, ""}; } void MPPTask::unregisterTask() @@ -211,7 +207,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) } // register tunnels - tunnel_set = std::make_shared(); + tunnel_set = std::make_shared(log->identifier()); std::chrono::seconds timeout(task_request.timeout()); for (int i = 0; i < exchange_sender.encoded_task_meta_size(); i++) @@ -225,7 +221,6 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) MPPTunnelPtr tunnel = std::make_shared(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier()); LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id()); registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); - tunnel_set->addTunnel(tunnel); if (!dag_context->isRootMPPTask()) { FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task); @@ -369,19 +364,8 @@ void MPPTask::runImpl() void MPPTask::writeErrToAllTunnels(const String & e) { - for (auto & it : tunnel_map) - { - try - { - FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel); - it.second->write(getPacketWithError(e), true); - } - catch (...) - { - it.second->close("Failed to write error msg to tunnel"); - tryLogCurrentException(log, "Failed to write error " + e + " to tunnel: " + it.second->id()); - } - } + RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set"); + tunnel_set->writeError(e); } void MPPTask::cancel(const String & reason) diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index c34cae49699..ee434a2f2ff 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -123,9 +123,6 @@ class MPPTask : public std::enable_shared_from_this MPPTunnelSetPtr tunnel_set; - // which targeted task we should send data by which tunnel. - std::unordered_map tunnel_map; - MPPTaskManager * manager = nullptr; const LoggerPtr log; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 12de07d4a18..500e9501b08 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -13,11 +13,17 @@ // limitations under the License. #include +#include #include +#include #include namespace DB { +namespace FailPoints +{ +extern const char exception_during_mpp_write_err_to_tunnel[]; +} // namespace FailPoints namespace { inline mpp::MPPDataPacket serializeToPacket(const tipb::SelectResponse & response) @@ -108,6 +114,65 @@ void MPPTunnelSetBase::write(mpp::MPPDataPacket & packet, int16_t partit tunnels[partition_id]->write(packet); } +template +void MPPTunnelSetBase::writeError(const String & msg) +{ + for (auto & tunnel : tunnels) + { + try + { + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel); + tunnel->write(getPacketWithError(msg), true); + } + catch (...) + { + tunnel->close("Failed to write error msg to tunnel"); + tryLogCurrentException(log, "Failed to write error " + msg + " to tunnel: " + tunnel->id()); + } + } +} + +template +void MPPTunnelSetBase::registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel) +{ + if (id_to_index_map.find(id) != id_to_index_map.end()) + throw Exception("the tunnel " + tunnel->id() + " has been registered"); + + id_to_index_map[id] = tunnels.size(); + tunnels.push_back(tunnel); + if (!tunnel->isLocal()) + { + remote_tunnel_cnt++; + } +} + +template +void MPPTunnelSetBase::close(const String & reason) +{ + for (auto & tunnel : tunnels) + tunnel->close(reason); +} + +template +void MPPTunnelSetBase::finishWrite() +{ + for (auto & tunnel : tunnels) + { + tunnel->writeDone(); + } +} + +template +typename MPPTunnelSetBase::TunnelPtr MPPTunnelSetBase::getTunnelById(const MPPTaskId & id) +{ + auto it = id_to_index_map.find(id); + if (it == id_to_index_map.end()) + { + return nullptr; + } + return tunnels[it->second]; +} + /// Explicit template instantiations - to avoid code bloat in headers. template class MPPTunnelSetBase; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index f2279b945cb..021c609f516 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -14,6 +14,7 @@ #pragma once +#include #include #ifdef __clang__ #pragma clang diagnostic push @@ -32,6 +33,9 @@ class MPPTunnelSetBase : private boost::noncopyable { public: using TunnelPtr = std::shared_ptr; + explicit MPPTunnelSetBase(const String & req_id) + : log(Logger::get("MPPTunnelSet", req_id)) + {} void clearExecutionSummaries(tipb::SelectResponse & response); @@ -50,17 +54,14 @@ class MPPTunnelSetBase : private boost::noncopyable // this is a partition writing. void write(tipb::SelectResponse & response, int16_t partition_id); void write(mpp::MPPDataPacket & packet, int16_t partition_id); + void writeError(const String & msg); + void close(const String & reason); + void finishWrite(); + void registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel); - uint16_t getPartitionNum() const { return tunnels.size(); } + TunnelPtr getTunnelById(const MPPTaskId & id); - void addTunnel(const TunnelPtr & tunnel) - { - tunnels.push_back(tunnel); - if (!tunnel->isLocal()) - { - remote_tunnel_cnt++; - } - } + uint16_t getPartitionNum() const { return tunnels.size(); } int getRemoteTunnelCnt() { @@ -71,6 +72,8 @@ class MPPTunnelSetBase : private boost::noncopyable private: std::vector tunnels; + std::unordered_map id_to_index_map; + const LoggerPtr log; int remote_tunnel_cnt = 0; };