diff --git a/dbms/src/Flash/Coprocessor/DAGContext.cpp b/dbms/src/Flash/Coprocessor/DAGContext.cpp index 47c3a3b2450..6167090194a 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.cpp +++ b/dbms/src/Flash/Coprocessor/DAGContext.cpp @@ -223,8 +223,12 @@ void DAGContext::addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_r { if (!isMPPTask()) return; - RUNTIME_ASSERT(mpp_receiver_set != nullptr, log, "MPPTask without receiver set"); - return mpp_receiver_set->addCoprocessorReader(coprocessor_reader); + coprocessor_readers.push_back(coprocessor_reader); +} + +std::vector & DAGContext::getCoprocessorReaders() +{ + return coprocessor_readers; } bool DAGContext::containsRegionsInfoForTable(Int64 table_id) const diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h index 8ea5a76df55..93e7edda7e8 100644 --- a/dbms/src/Flash/Coprocessor/DAGContext.h +++ b/dbms/src/Flash/Coprocessor/DAGContext.h @@ -306,6 +306,7 @@ class DAGContext mpp_receiver_set = receiver_set; } void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader); + std::vector & getCoprocessorReaders(); void addSubquery(const String & subquery_id, SubqueryForSet && subquery); bool hasSubquery() const { return !subqueries.empty(); } @@ -373,6 +374,7 @@ class DAGContext std::atomic warning_count; MPPReceiverSetPtr mpp_receiver_set; + std::vector coprocessor_readers; /// vector of SubqueriesForSets(such as join build subquery). /// The order of the vector is also the order of the subquery. std::vector subqueries; diff --git a/dbms/src/Flash/EstablishCall.cpp b/dbms/src/Flash/EstablishCall.cpp index 62a10e3afb6..a9eb057364b 100644 --- a/dbms/src/Flash/EstablishCall.cpp +++ b/dbms/src/Flash/EstablishCall.cpp @@ -106,10 +106,8 @@ bool EstablishCallData::write(const mpp::MPPDataPacket & packet) void EstablishCallData::writeErr(const mpp::MPPDataPacket & packet) { state = ERR_HANDLE; - if (write(packet)) - err_status = grpc::Status::OK; - else - err_status = grpc::Status(grpc::StatusCode::UNKNOWN, "Write error message failed for unknown reason."); + err_status = grpc::Status::OK; + write(packet); } void EstablishCallData::setFinishState(const String & msg) diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 72a7c9d1435..5ea7b527e0f 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -96,10 +96,12 @@ void MPPTask::abortTunnels(const String & message, AbortType abort_type) void MPPTask::abortReceivers() { - if (likely(receiver_set != nullptr)) { - receiver_set->cancel(); + std::unique_lock lock(tunnel_and_receiver_mu); + if unlikely (receiver_set == nullptr) + return; } + receiver_set->cancel(); } void MPPTask::abortDataStreams(AbortType abort_type) @@ -111,8 +113,12 @@ void MPPTask::abortDataStreams(AbortType abort_type) void MPPTask::closeAllTunnels(const String & reason) { - if (likely(tunnel_set)) - tunnel_set->close(reason); + { + std::unique_lock lock(tunnel_and_receiver_mu); + if (unlikely(tunnel_set == nullptr)) + return; + } + tunnel_set->close(reason); } void MPPTask::finishWrite() @@ -128,7 +134,7 @@ void MPPTask::run() void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) { - tunnel_set = std::make_shared(log->identifier()); + auto tunnel_set_local = std::make_shared(log->identifier()); std::chrono::seconds timeout(task_request.timeout()); const auto & exchange_sender = dag_req.root_executor().exchange_sender(); @@ -144,17 +150,24 @@ void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id()); if (status != INITIALIZING) throw Exception(fmt::format("The tunnel {} can not be registered, because the task is not in initializing state", tunnel->id())); - tunnel_set->registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); + tunnel_set_local->registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel); if (!dag_context->isRootMPPTask()) { FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task); } } + { + std::unique_lock lock(tunnel_and_receiver_mu); + if (status != INITIALIZING) + throw Exception(fmt::format("The tunnels can not be registered, because the task is not in initializing state")); + tunnel_set = std::move(tunnel_set_local); + } + dag_context->tunnel_set = tunnel_set; } void MPPTask::initExchangeReceivers() { - receiver_set = std::make_shared(log->identifier()); + auto receiver_set_local = std::make_shared(log->identifier()); traverseExecutors(&dag_req, [&](const tipb::Executor & executor) { if (executor.tp() == tipb::ExecType::TypeExchangeReceiver) { @@ -177,11 +190,17 @@ void MPPTask::initExchangeReceivers() if (status != RUNNING) throw Exception("exchange receiver map can not be initialized, because the task is not in running state"); - receiver_set->addExchangeReceiver(executor_id, exchange_receiver); + receiver_set_local->addExchangeReceiver(executor_id, exchange_receiver); new_thread_count_of_exchange_receiver += exchange_receiver->computeNewThreadCount(); } return true; }); + { + std::unique_lock lock(tunnel_and_receiver_mu); + if (status != RUNNING) + throw Exception("exchange receiver map can not be initialized, because the task is not in running state"); + receiver_set = std::move(receiver_set_local); + } dag_context->setMPPReceiverSet(receiver_set); } @@ -293,7 +312,6 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request) // register tunnels registerTunnels(task_request); - dag_context->tunnel_set = tunnel_set; // register task. auto task_manager = tmt_context.getMPPTaskManager(); LOG_FMT_DEBUG(log, "begin to register the task {}", id.toString()); @@ -320,6 +338,13 @@ void MPPTask::preprocess() auto start_time = Clock::now(); initExchangeReceivers(); executeQuery(*context); + { + std::unique_lock lock(tunnel_and_receiver_mu); + if (status != RUNNING) + throw Exception("task not in running state, may be cancelled"); + for (auto & r : dag_context->getCoprocessorReaders()) + receiver_set->addCoprocessorReader(r); + } auto end_time = Clock::now(); dag_context->compile_time_ns = std::chrono::duration_cast(end_time - start_time).count(); mpp_task_statistics.setCompileTimestamp(start_time, end_time); diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h index dfa9f8a2ea8..510178e2648 100644 --- a/dbms/src/Flash/Mpp/MPPTask.h +++ b/dbms/src/Flash/Mpp/MPPTask.h @@ -137,6 +137,8 @@ class MPPTask : public std::enable_shared_from_this MPPTaskId id; + std::mutex tunnel_and_receiver_mu; + MPPTunnelSetPtr tunnel_set; MPPReceiverSetPtr receiver_set; diff --git a/dbms/src/Flash/Mpp/MPPTunnel.cpp b/dbms/src/Flash/Mpp/MPPTunnel.cpp index e14d80aa5bd..8000e219e2f 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnel.cpp @@ -70,20 +70,7 @@ MPPTunnel::~MPPTunnel() }); try { - { - std::unique_lock lock(*mu); - if (status == TunnelStatus::Finished) - { - LOG_DEBUG(log, "already finished!"); - return; - } - - /// make sure to finish the tunnel after it is connected - waitUntilConnectedOrFinished(lock); - finishSendQueue(); - } - LOG_FMT_TRACE(log, "waiting consumer finish!"); - waitForSenderFinish(/*allow_throw=*/false); + close(""); } catch (...) {