Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move tunnel_map to MPPTunnelSet #5123

Merged
merged 4 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 15 additions & 31 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ MPPTask::MPPTask(const mpp::TaskMeta & meta_, const ContextPtr & context_)
, id(meta.start_ts(), meta.task_id())
, log(Logger::get("MPPTask", id.toString()))
, mpp_task_statistics(id, meta.address())
, needed_threads(0)
, schedule_state(ScheduleState::WAITING)
{}

Expand All @@ -78,34 +79,28 @@ MPPTask::~MPPTask()

// Closes the tunnels of this task, forwarding `reason` to the close call.
// NOTE(review): this hunk shows both sides of the diff without +/- markers.
// The loop over tunnel_map is the pre-change body (the tunnel_map member is
// deleted from MPPTask.h by this change); the tunnel_set branch replaces it.
void MPPTask::closeAllTunnels(const String & reason)
{
for (auto & it : tunnel_map)
{
it.second->close(reason);
}
// A null tunnel_set is tolerated here (nothing is closed), whereas finishWrite,
// registerTunnel and writeErrToAllTunnels use RUNTIME_ASSERT on it instead.
if (likely(tunnel_set))
tunnel_set->close(reason);
}

// Signals the end of writing on the tunnels of this task.
// NOTE(review): this hunk shows both sides of the diff without +/- markers.
// The loop calling writeDone() on every tunnel_map entry is the pre-change body;
// the RUNTIME_ASSERT + tunnel_set->finishWrite() pair replaces it.
void MPPTask::finishWrite()
{
for (const auto & it : tunnel_map)
{
it.second->writeDone();
}
// The post-change code requires tunnel_set to be non-null before delegating to it.
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->finishWrite();
}

void MPPTask::run()
{
newThreadManager()->scheduleThenDetach(true, "MPPTask", [self = shared_from_this()] { self->runImpl(); });
}

// Registers `tunnel` for the given target task id.
// Throws Exception when the task status is CANCELLED.
// NOTE(review): this hunk shows both sides of the diff without +/- markers.
// The first signature (parameter `id`) and the tunnel_map duplicate check and
// insertion are the pre-change code; the second signature (parameter `task_id`)
// and the RUNTIME_ASSERT + tunnel_set->registerTunnel() pair replace them.
void MPPTask::registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel)
void MPPTask::registerTunnel(const MPPTaskId & task_id, MPPTunnelPtr tunnel)
{
if (status == CANCELLED)
throw Exception("the tunnel " + tunnel->id() + " can not been registered, because the task is cancelled");

// Pre-change: reject an id that is already present, then store the tunnel in tunnel_map.
if (tunnel_map.find(id) != tunnel_map.end())
throw Exception("the tunnel " + tunnel->id() + " has been registered");

tunnel_map[id] = tunnel;
// Post-change: the tunnel set must exist, and it takes over the registration.
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->registerTunnel(task_id, tunnel);
}

std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConnectionRequest * request)
Expand All @@ -120,16 +115,17 @@ std::pair<MPPTunnelPtr, String> MPPTask::getTunnel(const ::mpp::EstablishMPPConn
}

MPPTaskId receiver_id{request->receiver_meta().start_ts(), request->receiver_meta().task_id()};
auto it = tunnel_map.find(receiver_id);
if (it == tunnel_map.end())
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
auto tunnel_ptr = tunnel_set->getTunnelById(receiver_id);
if (tunnel_ptr == nullptr)
{
auto err_msg = fmt::format(
"can't find tunnel ({} + {})",
request->sender_meta().task_id(),
request->receiver_meta().task_id());
return {nullptr, err_msg};
}
return {it->second, ""};
return {tunnel_ptr, ""};
}

void MPPTask::unregisterTask()
Expand Down Expand Up @@ -211,7 +207,7 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
}

// register tunnels
tunnel_set = std::make_shared<MPPTunnelSet>();
tunnel_set = std::make_shared<MPPTunnelSet>(log->identifier());
std::chrono::seconds timeout(task_request.timeout());

for (int i = 0; i < exchange_sender.encoded_task_meta_size(); i++)
Expand All @@ -225,7 +221,6 @@ void MPPTask::prepare(const mpp::DispatchTaskRequest & task_request)
MPPTunnelPtr tunnel = std::make_shared<MPPTunnel>(task_meta, task_request.meta(), timeout, context->getSettingsRef().max_threads, is_local, is_async, log->identifier());
LOG_FMT_DEBUG(log, "begin to register the tunnel {}", tunnel->id());
registerTunnel(MPPTaskId{task_meta.start_ts(), task_meta.task_id()}, tunnel);
tunnel_set->addTunnel(tunnel);
if (!dag_context->isRootMPPTask())
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_register_tunnel_for_non_root_mpp_task);
Expand Down Expand Up @@ -369,19 +364,8 @@ void MPPTask::runImpl()

// Sends the error message `e` to the tunnels of this task.
// NOTE(review): this hunk shows both sides of the diff without +/- markers.
// The loop over tunnel_map is the pre-change body; the RUNTIME_ASSERT +
// tunnel_set->writeError(e) pair replaces it.
void MPPTask::writeErrToAllTunnels(const String & e)
{
for (auto & it : tunnel_map)
{
try
{
// Fail point that lets an exception be injected before the write.
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel);
it.second->write(getPacketWithError(e), true);
}
catch (...)
{
// Any failure for one tunnel closes that tunnel and is logged; the loop then
// continues with the next tunnel because nothing is rethrown here.
it.second->close("Failed to write error msg to tunnel");
tryLogCurrentException(log, "Failed to write error " + e + " to tunnel: " + it.second->id());
}
}
// Post-change: the tunnel set must exist, and it takes over the error write.
RUNTIME_ASSERT(tunnel_set != nullptr, log, "mpp task without tunnel set");
tunnel_set->writeError(e);
}

void MPPTask::cancel(const String & reason)
Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,6 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

MPPTunnelSetPtr tunnel_set;

// which targeted task we should send data by which tunnel.
std::unordered_map<MPPTaskId, MPPTunnelPtr> tunnel_map;

MPPTaskManager * manager = nullptr;

const LoggerPtr log;
Expand Down
65 changes: 65 additions & 0 deletions dbms/src/Flash/Mpp/MPPTunnelSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
// limitations under the License.

#include <Common/Exception.h>
#include <Common/FailPoint.h>
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/Utils.h>
#include <fmt/core.h>

namespace DB
{
// Fail point names referenced in this file; declared extern, so the definitions
// are not in this translation unit.
namespace FailPoints
{
// Used by MPPTunnelSetBase::writeError in FAIL_POINT_TRIGGER_EXCEPTION.
extern const char exception_during_mpp_write_err_to_tunnel[];
} // namespace FailPoints
namespace
{
inline mpp::MPPDataPacket serializeToPacket(const tipb::SelectResponse & response)
Expand Down Expand Up @@ -108,6 +114,65 @@ void MPPTunnelSetBase<Tunnel>::write(mpp::MPPDataPacket & packet, int16_t partit
tunnels[partition_id]->write(packet);
}

template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::writeError(const String & msg)
{
for (auto & tunnel : tunnels)
{
try
{
FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_write_err_to_tunnel);
tunnel->write(getPacketWithError(msg), true);
}
catch (...)
{
tunnel->close("Failed to write error msg to tunnel");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original code, tunnel->close is the first statement in the catch block; is the behavior the same here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the original implementation.

tryLogCurrentException(log, "Failed to write error " + msg + " to tunnel: " + tunnel->id());
}
}
}

/// Adds `tunnel` to the set and records, under the target task `id`, the index it
/// occupies in `tunnels`. Tunnels for which isLocal() is false are also counted
/// in `remote_tunnel_cnt`.
///
/// Throws Exception when `id` is already registered; in that case neither the
/// index map, the tunnel list nor the remote counter is modified.
template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel)
{
    // emplace() performs the duplicate check and the insertion with one hash
    // lookup; it leaves the map untouched when the key already exists.
    const bool inserted = id_to_index_map.emplace(id, tunnels.size()).second;
    if (!inserted)
        throw Exception("the tunnel " + tunnel->id() + " has been registered");

    // The index stored above is the position this push_back fills.
    tunnels.push_back(tunnel);
    if (!tunnel->isLocal())
        ++remote_tunnel_cnt;
}

/// Calls close(reason) on every tunnel in the set, in registration order.
template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::close(const String & reason)
{
    for (const auto & tunnel_ptr : tunnels)
    {
        tunnel_ptr->close(reason);
    }
}

/// Calls writeDone() on every tunnel in the set, in registration order.
template <typename Tunnel>
void MPPTunnelSetBase<Tunnel>::finishWrite()
{
    for (const auto & tunnel_ptr : tunnels)
        tunnel_ptr->writeDone();
}

/// Looks up the tunnel registered under the target task `id`.
/// Returns the tunnel, or nullptr when no tunnel is registered under `id`.
template <typename Tunnel>
typename MPPTunnelSetBase<Tunnel>::TunnelPtr MPPTunnelSetBase<Tunnel>::getTunnelById(const MPPTaskId & id)
{
    const auto pos = id_to_index_map.find(id);
    if (pos != id_to_index_map.end())
        return tunnels[pos->second];
    return nullptr;
}

/// Explicit template instantiations - to avoid code bloat in headers.
template class MPPTunnelSetBase<MPPTunnel>;

Expand Down
21 changes: 12 additions & 9 deletions dbms/src/Flash/Mpp/MPPTunnelSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <Flash/Mpp/MPPTaskId.h>
#include <Flash/Mpp/MPPTunnel.h>
#ifdef __clang__
#pragma clang diagnostic push
Expand All @@ -32,6 +33,9 @@ class MPPTunnelSetBase : private boost::noncopyable
{
public:
using TunnelPtr = std::shared_ptr<Tunnel>;
/// Constructs the tunnel set. `req_id` is passed to Logger::get together with
/// the name "MPPTunnelSet" to create this set's logger.
explicit MPPTunnelSetBase(const String & req_id)
: log(Logger::get("MPPTunnelSet", req_id))
{}

void clearExecutionSummaries(tipb::SelectResponse & response);

Expand All @@ -50,17 +54,14 @@ class MPPTunnelSetBase : private boost::noncopyable
// this is a partition writing.
void write(tipb::SelectResponse & response, int16_t partition_id);
void write(mpp::MPPDataPacket & packet, int16_t partition_id);
void writeError(const String & msg);
void close(const String & reason);
void finishWrite();
void registerTunnel(const MPPTaskId & id, const TunnelPtr & tunnel);

uint16_t getPartitionNum() const { return tunnels.size(); }
TunnelPtr getTunnelById(const MPPTaskId & id);

void addTunnel(const TunnelPtr & tunnel)
{
tunnels.push_back(tunnel);
if (!tunnel->isLocal())
{
remote_tunnel_cnt++;
}
}
uint16_t getPartitionNum() const { return tunnels.size(); }

int getRemoteTunnelCnt()
{
Expand All @@ -71,6 +72,8 @@ class MPPTunnelSetBase : private boost::noncopyable

private:
std::vector<TunnelPtr> tunnels;
std::unordered_map<MPPTaskId, size_t> id_to_index_map;
bestwoody marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::unordered_map<MPPTaskId, size_t> id_to_index_map;
std::unordered_map<MPPTaskId, size_t> target_id_to_index_map;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, but the PR was already merged :(

const LoggerPtr log;

int remote_tunnel_cnt = 0;
};
Expand Down