Skip to content

Commit

Permalink
Some refinements of mpp_exchange_receiver_map and MPPTunnelSet (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
windtalker authored Jun 14, 2022
1 parent bcb837b commit 864cfe9
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 110 deletions.
51 changes: 2 additions & 49 deletions dbms/src/Flash/Coprocessor/DAGContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,60 +205,13 @@ void DAGContext::attachBlockIO(const BlockIO & io_)
{
io = io_;
}
/// For an MPP task, creates one ExchangeReceiver per ExchangeReceiver executor
/// found in the DAG request and stores it in mpp_exchange_receiver_map under the
/// executor's id. Non-MPP tasks are left untouched.
/// Throws TiFlashException if the map has already been initialized.
void DAGContext::initExchangeReceiverIfMPP(Context & context, size_t max_streams)
{
    if (!isMPPTask())
        return;

    if (mpp_exchange_receiver_map_inited)
        throw TiFlashException("Repeatedly initialize mpp_exchange_receiver_map", Errors::Coprocessor::Internal);

    auto register_receiver = [&](const tipb::Executor & executor) {
        // Every executor is visited; only exchange receivers need any work.
        if (executor.tp() != tipb::ExecType::TypeExchangeReceiver)
            return true;

        assert(executor.has_executor_id());
        // The executor id is what tells different exchange receivers apart.
        const auto & receiver_id = executor.executor_id();
        auto receiver = std::make_shared<ExchangeReceiver>(
            std::make_shared<GRPCReceiverContext>(
                executor.exchange_receiver(),
                getMPPTaskMeta(),
                context.getTMTContext().getKVCluster(),
                context.getTMTContext().getMPPTaskManager(),
                context.getSettingsRef().enable_local_tunnel,
                context.getSettingsRef().enable_async_grpc_client),
            executor.exchange_receiver().encoded_task_meta_size(),
            max_streams,
            log->identifier(),
            receiver_id);
        mpp_exchange_receiver_map[receiver_id] = receiver;
        new_thread_count_of_exchange_receiver += receiver->computeNewThreadCount();
        return true;
    };
    traverseExecutors(dag_request, register_receiver);

    mpp_exchange_receiver_map_inited = true;
}


const std::unordered_map<String, std::shared_ptr<ExchangeReceiver>> & DAGContext::getMPPExchangeReceiverMap() const
{
if (!isMPPTask())
throw TiFlashException("mpp_exchange_receiver_map is used in mpp only", Errors::Coprocessor::Internal);
if (!mpp_exchange_receiver_map_inited)
throw TiFlashException("mpp_exchange_receiver_map has not been initialized", Errors::Coprocessor::Internal);
return mpp_exchange_receiver_map;
}

/// Calls cancel() on every exchange receiver stored in mpp_exchange_receiver_map.
void DAGContext::cancelAllExchangeReceiver()
{
    for (auto & id_and_receiver : mpp_exchange_receiver_map)
        id_and_receiver.second->cancel();
}

int DAGContext::getNewThreadCountOfExchangeReceiver() const
{
return new_thread_count_of_exchange_receiver;
RUNTIME_ASSERT(mpp_exchange_receiver_map != nullptr, log, "MPPTask without exchange receiver map");
return *mpp_exchange_receiver_map;
}

bool DAGContext::containsRegionsInfoForTable(Int64 table_id) const
Expand Down
14 changes: 7 additions & 7 deletions dbms/src/Flash/Coprocessor/DAGContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace DB
class Context;
class MPPTunnelSet;
class ExchangeReceiver;
using ExchangeReceiverMap = std::unordered_map<String, std::shared_ptr<ExchangeReceiver>>;
using ExchangeReceiverMapPtr = std::shared_ptr<std::unordered_map<String, std::shared_ptr<ExchangeReceiver>>>;

class Join;
using JoinPtr = std::shared_ptr<Join>;
Expand Down Expand Up @@ -254,7 +256,6 @@ class DAGContext
return io;
}

int getNewThreadCountOfExchangeReceiver() const;
UInt64 getFlags() const
{
return flags;
Expand Down Expand Up @@ -303,10 +304,11 @@ class DAGContext

bool columnsForTestEmpty() { return columns_for_test_map.empty(); }

void cancelAllExchangeReceiver();

void initExchangeReceiverIfMPP(Context & context, size_t max_streams);
const std::unordered_map<String, std::shared_ptr<ExchangeReceiver>> & getMPPExchangeReceiverMap() const;
void setMPPExchangeReceiverMap(ExchangeReceiverMapPtr & exchange_receiver_map)
{
mpp_exchange_receiver_map = exchange_receiver_map;
}

void addSubquery(const String & subquery_id, SubqueryForSet && subquery);
bool hasSubquery() const { return !subqueries.empty(); }
Expand Down Expand Up @@ -367,10 +369,8 @@ class DAGContext
ConcurrentBoundedQueue<tipb::Error> warnings;
/// warning_count is the actual warning count during the entire execution
std::atomic<UInt64> warning_count;
int new_thread_count_of_exchange_receiver = 0;
/// key: executor_id of ExchangeReceiver nodes in dag.
std::unordered_map<String, std::shared_ptr<ExchangeReceiver>> mpp_exchange_receiver_map;
bool mpp_exchange_receiver_map_inited = false;
ExchangeReceiverMapPtr mpp_exchange_receiver_map;
/// vector of SubqueriesForSets(such as join build subquery).
/// The order of the vector is also the order of the subquery.
std::vector<SubqueriesForSets> subqueries;
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,7 @@ tipb::EncodeType analyzeDAGEncodeType(DAGContext & dag_context)
return tipb::EncodeType::TypeDefault;
return encode_type;
}

tipb::ScalarFuncSig reverseGetFuncSigByFuncName(const String & name)
{
static std::unordered_map<String, tipb::ScalarFuncSig> func_name_sig_map = getFuncNameToSigMap();
Expand Down
17 changes: 1 addition & 16 deletions dbms/src/Flash/Coprocessor/InterpreterDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,8 @@ namespace DB
InterpreterDAG::InterpreterDAG(Context & context_, const DAGQuerySource & dag_)
: context(context_)
, dag(dag_)
, max_streams(context.getMaxStreams())
{
const Settings & settings = context.getSettingsRef();
if (dagContext().isBatchCop() || (dagContext().isMPPTask() && !dagContext().isTest()))
max_streams = settings.max_threads;
else if (dagContext().isTest())
max_streams = dagContext().initialize_concurrency;
else
max_streams = 1;

if (max_streams > 1)
{
max_streams *= settings.max_streams_to_max_threads_ratio;
}
}

void setRestorePipelineConcurrency(DAGQueryBlock & query_block)
Expand Down Expand Up @@ -75,10 +64,6 @@ BlockInputStreams InterpreterDAG::executeQueryBlock(DAGQueryBlock & query_block)

BlockIO InterpreterDAG::execute()
{
/// Due to learner read, DAGQueryBlockInterpreter may take a long time to build
/// the query plan, so we init mpp exchange receiver before executeQueryBlock
dagContext().initExchangeReceiverIfMPP(context, max_streams);

BlockInputStreams streams = executeQueryBlock(*dag.getRootQueryBlock());
DAGPipeline pipeline;
pipeline.streams = streams;
Expand Down
100 changes: 73 additions & 27 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
#include <Flash/Coprocessor/DAGCodec.h>
#include <Flash/Coprocessor/DAGUtils.h>
#include <Flash/CoprocessorHandler.h>
#include <Flash/Mpp/ExchangeReceiver.h>
#include <Flash/Mpp/GRPCReceiverContext.h>
#include <Flash/Mpp/MPPTask.h>
#include <Flash/Mpp/MPPTaskManager.h>
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/MinTSOScheduler.h>
#include <Flash/Mpp/Utils.h>
#include <Flash/Statistics/traverseExecutors.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/executeQuery.h>
#include <Storages/Transaction/KVStore.h>
Expand Down Expand Up @@ -94,13 +97,73 @@ void MPPTask::run()
newThreadManager()->scheduleThenDetach(true, "MPPTask", [self = shared_from_this()] { self->runImpl(); });
}

void MPPTask::registerTunnel(const MPPTaskId & task_id, MPPTunnelPtr tunnel)
void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request)
{
if (status == CANCELLED)
throw Exception("the tunnel " + tunnel->id() + " can not been registered, because the task is cancelled");
tunnel_set = std::make_shared<MPPTunnelSet>(log->identifier());
std::chrono::seconds timeout(task_request.timeout());
const auto & exchange_sender = dag_req.root_executor().exchange_sender();

RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->registerTunnel(task_id, tunnel);
for (int i = 0; i < exchange_sender.encoded_task_meta_size(); ++i)
{
// exchange sender will register the tunnels and wait for the receiver to establish a connection.
mpp::TaskMeta task_meta;
if (unlikely(!task_meta.ParseFromString(exchange_sender.encoded_task_meta(i))))
throw TiFlashException("Failed to decode task meta info in ExchangeSender", Errors::Coprocessor::BadRequest);
bool is_local = context->getSettingsRef().enable_local_tunnel && meta.address() == task_meta.address();
bool is_async = !is_local && context->getSettingsRef().enable_async_server;
MPPTunnelPtr tunnel = std::make_shared<MPPTunnel>(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier());
LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id());
if (status != INITIALIZING)
throw Exception(fmt::format("The tunnel {} can not be registered, because the task is not in initializing state", tunnel->id()));
tunnel_set->registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel);
if (!dag_context->isRootMPPTask())
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task);
}
}
}

void MPPTask::initExchangeReceivers()
{
mpp_exchange_receiver_map = std::make_shared<ExchangeReceiverMap>();
traverseExecutors(&dag_req, [&](const tipb::Executor & executor) {
if (executor.tp() == tipb::ExecType::TypeExchangeReceiver)
{
assert(executor.has_executor_id());
const auto & executor_id = executor.executor_id();
// In order to distinguish different exchange receivers.
auto exchange_receiver = std::make_shared<ExchangeReceiver>(
std::make_shared<GRPCReceiverContext>(
executor.exchange_receiver(),
dag_context->getMPPTaskMeta(),
context->getTMTContext().getKVCluster(),
context->getTMTContext().getMPPTaskManager(),
context->getSettingsRef().enable_local_tunnel,
context->getSettingsRef().enable_async_grpc_client),
executor.exchange_receiver().encoded_task_meta_size(),
context->getMaxStreams(),
log->identifier(),
executor_id);
if (status != RUNNING)
throw Exception("exchange receiver map can not be initialized, because the task is not in running state");

(*mpp_exchange_receiver_map)[executor_id] = exchange_receiver;
new_thread_count_of_exchange_receiver += exchange_receiver->computeNewThreadCount();
}
return true;
});
dag_context->setMPPExchangeReceiverMap(mpp_exchange_receiver_map);
}

/// Calls cancel() on every exchange receiver held by this task.
/// Does nothing when the receiver map has not been created.
void MPPTask::cancelAllExchangeReceivers()
{
    if (!likely(mpp_exchange_receiver_map != nullptr))
        return;

    for (auto & id_and_receiver : *mpp_exchange_receiver_map)
        id_and_receiver.second->cancel();
}

std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConnectionRequest * request)
Expand All @@ -116,7 +179,7 @@ std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConn

MPPTaskId receiver_id{request->receiver_meta().start_ts(), request->receiver_meta().task_id()};
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
auto tunnel_ptr = tunnel_set->getTunnelById(receiver_id);
auto tunnel_ptr = tunnel_set->getTunnelByReceiverTaskId(receiver_id);
if (tunnel_ptr == nullptr)
{
auto err_msg = fmt::format(
Expand Down Expand Up @@ -207,25 +270,8 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
}

// register tunnels
tunnel_set = std::make_shared<MPPTunnelSet>(log->identifier());
std::chrono::seconds timeout(task_request.timeout());
registerTunnels(task_request);

for (int i = 0; i < exchange_sender.encoded_task_meta_size(); i++)
{
// exchange sender will register the tunnels and wait for the receiver to establish a connection.
mpp::TaskMeta task_meta;
if (!task_meta.ParseFromString(exchange_sender.encoded_task_meta(i)))
throw TiFlashException("Failed to decode task meta info in ExchangeSender", Errors::Coprocessor::BadRequest);
bool is_local = context->getSettingsRef().enable_local_tunnel && meta.address() == task_meta.address();
bool is_async = !is_local && context->getSettingsRef().enable_async_server;
MPPTunnelPtr tunnel = std::make_shared<MPPTunnel>(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier());
LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id());
registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel);
if (!dag_context->isRootMPPTask())
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task);
}
}
dag_context->tunnel_set = tunnel_set;
// register task.
auto task_manager = tmt_context.getMPPTaskManager();
Expand All @@ -251,6 +297,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
void MPPTask::preprocess()
{
auto start_time = Clock::now();
initExchangeReceivers();
DAGQuerySource dag(*context);
executeQuery(dag, *context, false, QueryProcessingStage::Complete);
auto end_time = Clock::now();
Expand Down Expand Up @@ -280,7 +327,7 @@ void MPPTask::runImpl()
LOG_FMT_INFO(log, "task starts preprocessing");
preprocess();
needed_threads = estimateCountOfNewThreads();
LOG_FMT_DEBUG(log, "Estimate new thread count of query :{} including tunnel_threads: {} , receiver_threads: {}", needed_threads, dag_context->tunnel_set->getRemoteTunnelCnt(), dag_context->getNewThreadCountOfExchangeReceiver());
LOG_FMT_DEBUG(log, "Estimate new thread count of query :{} including tunnel_threads: {} , receiver_threads: {}", needed_threads, dag_context->tunnel_set->getRemoteTunnelCnt(), new_thread_count_of_exchange_receiver);

scheduleOrWait();

Expand Down Expand Up @@ -346,8 +393,7 @@ void MPPTask::runImpl()
else
{
context->getProcessList().sendCancelToQuery(context->getCurrentQueryId(), context->getClientInfo().current_user, true);
if (dag_context)
dag_context->cancelAllExchangeReceiver();
cancelAllExchangeReceivers();
writeErrToAllTunnels(err_msg);
}
LOG_FMT_INFO(log, "task ends, time cost is {} ms.", stopwatch.elapsedMilliseconds());
Expand Down
12 changes: 10 additions & 2 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

void run();

void registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel);

int getNeededThreads();

enum class ScheduleState
Expand Down Expand Up @@ -107,6 +105,12 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

int estimateCountOfNewThreads();

void registerTunnels(const mpp::DispatchTaskRequest & task_request);

void initExchangeReceivers();

void cancelAllExchangeReceivers();

tipb::DAGRequest dag_req;

ContextPtr context;
Expand All @@ -122,6 +126,10 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>
MPPTaskId id;

MPPTunnelSetPtr tunnel_set;
/// key: executor_id of ExchangeReceiver nodes in dag.
ExchangeReceiverMapPtr mpp_exchange_receiver_map;

int new_thread_count_of_exchange_receiver = 0;

MPPTaskManager * manager = nullptr;

Expand Down
14 changes: 7 additions & 7 deletions dbms/src/Flash/Mpp/MPPTunnelSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ void MPPTunnelSetBase<Tunnel>::writeError(const String & msg)
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel)
void MPPTunnelSetBase<Tunnel>::registerTunnel(const MPPTaskId & receiver_task_id, const TunnelPtr & tunnel)
{
if (id_to_index_map.find(id) != id_to_index_map.end())
throw Exception("the tunnel " + tunnel->id() + " has been registered");
if (receiver_task_id_to_index_map.find(receiver_task_id) != receiver_task_id_to_index_map.end())
throw Exception(fmt::format("the tunnel {} has been registered", tunnel->id()));

id_to_index_map[id] = tunnels.size();
receiver_task_id_to_index_map[receiver_task_id] = tunnels.size();
tunnels.push_back(tunnel);
if (!tunnel->isLocal())
{
Expand All @@ -163,10 +163,10 @@ void MPPTunnelSetBase<Tunnel>::finishWrite()
}

template <typename Tunnel>
typename MPPTunnelSetBase<Tunnel>::TunnelPtr MPPTunnelSetBase<Tunnel>::getTunnelById(const MPPTaskId & id)
typename MPPTunnelSetBase<Tunnel>::TunnelPtr MPPTunnelSetBase<Tunnel>::getTunnelByReceiverTaskId(const MPPTaskId & id)
{
auto it = id_to_index_map.find(id);
if (it == id_to_index_map.end())
auto it = receiver_task_id_to_index_map.find(id);
if (it == receiver_task_id_to_index_map.end())
{
return nullptr;
}
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Flash/Mpp/MPPTunnelSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class MPPTunnelSetBase : private boost::noncopyable
void finishWrite();
/// Registers `tunnel` under the task id of the receiver it serves.
/// Throws if a tunnel is already registered for that receiver task id.
void registerTunnel(const MPPTaskId & receiver_task_id, const TunnelPtr & tunnel);

TunnelPtr getTunnelById(const MPPTaskId & id);
TunnelPtr getTunnelByReceiverTaskId(const MPPTaskId & id);

uint16_t getPartitionNum() const { return tunnels.size(); }

Expand All @@ -72,7 +72,7 @@ class MPPTunnelSetBase : private boost::noncopyable

private:
std::vector<TunnelPtr> tunnels;
std::unordered_map<MPPTaskId, size_t> id_to_index_map;
std::unordered_map<MPPTaskId, size_t> receiver_task_id_to_index_map;
const LoggerPtr log;

int remote_tunnel_cnt = 0;
Expand Down
Loading

0 comments on commit 864cfe9

Please sign in to comment.