Skip to content

Commit

Permalink
move tunnel_map to MPPTunnelSet
Browse files Browse the repository at this point in the history
Signed-off-by: xufei <[email protected]>
  • Loading branch information
windtalker committed Jun 10, 2022
1 parent 1e3207d commit 193de75
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 35 deletions.
45 changes: 14 additions & 31 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,34 +78,28 @@ MPPTask::~MPPTask()

void MPPTask::closeAllTunnels(const String & reason)
{
for (auto & it : tunnel_map)
{
it.second->close(reason);
}
if (likely(tunnel_set))
tunnel_set->close(reason);
}

void MPPTask::finishWrite()
{
for (const auto & it : tunnel_map)
{
it.second->writeDone();
}
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->finishWrite();
}

void MPPTask::run()
{
newThreadManager()->scheduleThenDetach(true, "MPPTask", [self = shared_from_this()] { self->runImpl(); });
}

void MPPTask::registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel)
void MPPTask::registerTunnel(const MPPTaskId & task_id, MPPTunnelPtr tunnel)
{
if (status == CANCELLED)
throw Exception("the tunnel " + tunnel->id() + " can not been registered, because the task is cancelled");

if (tunnel_map.find(id) != tunnel_map.end())
throw Exception("the tunnel " + tunnel->id() + " has been registered");

tunnel_map[id] = tunnel;
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->registerTunnel(task_id, tunnel);
}

std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConnectionRequest * request)
Expand All @@ -120,16 +114,17 @@ std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConn
}

MPPTaskId receiver_id{request->receiver_meta().start_ts(), request->receiver_meta().task_id()};
auto it = tunnel_map.find(receiver_id);
if (it == tunnel_map.end())
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
auto tunnel_ptr = tunnel_set->getTunnelById(receiver_id);
if (tunnel_ptr == nullptr)
{
auto err_msg = fmt::format(
"can't find tunnel ({} + {})",
request->sender_meta().task_id(),
request->receiver_meta().task_id());
return {nullptr, err_msg};
}
return {it->second, ""};
return {tunnel_ptr, ""};
}

void MPPTask::unregisterTask()
Expand Down Expand Up @@ -211,7 +206,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
}

// register tunnels
tunnel_set = std::make_shared<MPPTunnelSet>();
tunnel_set = std::make_shared<MPPTunnelSet>(log->identifier());
std::chrono::seconds timeout(task_request.timeout());

for (int i = 0; i < exchange_sender.encoded_task_meta_size(); i++)
Expand All @@ -225,7 +220,6 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
MPPTunnelPtr tunnel = std::make_shared<MPPTunnel>(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier());
LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id());
registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel);
tunnel_set->addTunnel(tunnel);
if (!dag_context->isRootMPPTask())
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task);
Expand Down Expand Up @@ -369,19 +363,8 @@ void MPPTask::runImpl()

void MPPTask::writeErrToAllTunnels(const String & e)
{
for (auto & it : tunnel_map)
{
try
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel);
it.second->write(getPacketWithError(e), true);
}
catch (...)
{
it.second->close("Failed to write error msg to tunnel");
tryLogCurrentException(log, "Failed to write error " + e + " to tunnel: " + it.second->id());
}
}
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->writeError(e);
}

void MPPTask::cancel(const String & reason)
Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

MPPTunnelSetPtr tunnel_set;

// which targeted task we should send data by which tunnel.
std::unordered_map<MPPTaskId, MPPTunnelPtr> tunnel_map;

MPPTaskManager * manager = nullptr;

const LoggerPtr log;
Expand Down
51 changes: 51 additions & 0 deletions dbms/src/Flash/Mpp/MPPTunnelSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
// limitations under the License.

#include <Common/Exception.h>
#include <Common/FailPoint.h>
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/Utils.h>
#include <fmt/core.h>

namespace DB
{
namespace FailPoints
{
extern const char exception_during_mpp_write_err_to_tunnel[];
} // namespace FailPoints
namespace
{
inline mpp::MPPDataPacket serializeToPacket(const tipb::SelectResponse & response)
Expand Down Expand Up @@ -108,6 +114,51 @@ void MPPTunnelSetBase<Tunnel>::write(mpp::MPPDataPacket & packet, int16_t partit
tunnels[partition_id]->write(packet);
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::writeError(const String & msg)
{
for (auto & tunnel : tunnels)
{
try
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel);
tunnel->write(getPacketWithError(msg), true);
}
catch (...)
{
tryLogCurrentException(log, "Failed to write error " + msg + " to tunnel: " + tunnel->id());
tunnel->close("Failed to write error msg to tunnel");
}
}
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::close(const String & reason)
{
for (auto & tunnel : tunnels)
tunnel->close(reason);
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::finishWrite()
{
for (auto & tunnel : tunnels)
{
tunnel->writeDone();
}
}

template <typename Tunnel>
typename MPPTunnelSetBase<Tunnel>::TunnelPtr MPPTunnelSetBase<Tunnel>::getTunnelById(const MPPTaskId & id)
{
auto it = id_to_index_map.find(id);
if (it == id_to_index_map.end())
{
return nullptr;
}
return tunnels[it->second];
}

/// Explicit template instantiations - to avoid code bloat in headers.
template class MPPTunnelSetBase<MPPTunnel>;

Expand Down
16 changes: 15 additions & 1 deletion dbms/src/Flash/Mpp/MPPTunnelSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <Flash/Mpp/MPPTaskId.h>
#include <Flash/Mpp/MPPTunnel.h>
#ifdef __clang__
#pragma clang diagnostic push
Expand All @@ -32,6 +33,9 @@ class MPPTunnelSetBase : private boost::noncopyable
{
public:
using TunnelPtr = std::shared_ptr<Tunnel>;
MPPTunnelSetBase(const String & req_id)
: log(Logger::get("MPPTunnelSet", req_id))
{}

void clearExecutionSummaries(tipb::SelectResponse & response);

Expand All @@ -50,11 +54,19 @@ class MPPTunnelSetBase : private boost::noncopyable
// this is a partition writing.
void write(tipb::SelectResponse & response, int16_t partition_id);
void write(mpp::MPPDataPacket & packet, int16_t partition_id);
void writeError(const String & msg);
void close(const String & reason);
void finishWrite();
TunnelPtr getTunnelById(const MPPTaskId & id);

uint16_t getPartitionNum() const { return tunnels.size(); }

void addTunnel(const TunnelPtr & tunnel)
void registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel)
{
if (id_to_index_map.find(id) != id_to_index_map.end())
throw Exception("the tunnel " + tunnel->id() + " has been registered");

id_to_index_map[id] = tunnels.size();
tunnels.push_back(tunnel);
if (!tunnel->isLocal())
{
Expand All @@ -71,6 +83,8 @@ class MPPTunnelSetBase : private boost::noncopyable

private:
std::vector<TunnelPtr> tunnels;
std::unordered_map<MPPTaskId, size_t> id_to_index_map;
const LoggerPtr log;

int remote_tunnel_cnt = 0;
};
Expand Down

0 comments on commit 193de75

Please sign in to comment.