Skip to content

Commit

Permalink
new thread for encoding batch cop
Browse files Browse the repository at this point in the history
Signed-off-by: JaySon-Huang <[email protected]>
  • Loading branch information
JaySon-Huang committed Jul 11, 2022
1 parent 34d7570 commit 08a7d13
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 39 deletions.
37 changes: 19 additions & 18 deletions dbms/src/Flash/Coprocessor/DAGContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Flash/Mpp/ExchangeReceiver.h>
#include <Flash/Statistics/traverseExecutors.h>
#include <Storages/Transaction/TMTContext.h>
#include <common/logger_useful.h>

namespace DB
{
Expand Down Expand Up @@ -86,25 +87,25 @@ void DAGContext::initExecutorIdToJoinIdMap()
{
// only mpp task has join executor
// for mpp, all executor has executor id.
if (isMPPTask())
{
executor_id_to_join_id_map.clear();
traverseExecutorsReverse(dag_request, [&](const tipb::Executor & executor) {
std::vector<String> all_join_id;
// for mpp, dag_request.has_root_executor() == true, can call `getChildren` directly.
getChildren(executor).forEach([&](const tipb::Executor & child) {
assert(child.has_executor_id());
auto child_it = executor_id_to_join_id_map.find(child.executor_id());
if (child_it != executor_id_to_join_id_map.end())
all_join_id.insert(all_join_id.end(), child_it->second.begin(), child_it->second.end());
});
assert(executor.has_executor_id());
if (executor.tp() == tipb::ExecType::TypeJoin)
all_join_id.push_back(executor.executor_id());
if (!all_join_id.empty())
executor_id_to_join_id_map[executor.executor_id()] = all_join_id;
if (!isMPPTask())
return;

executor_id_to_join_id_map.clear();
traverseExecutorsReverse(dag_request, [&](const tipb::Executor & executor) {
std::vector<String> all_join_id;
// for mpp, dag_request.has_root_executor() == true, can call `getChildren` directly.
getChildren(executor).forEach([&](const tipb::Executor & child) {
assert(child.has_executor_id());
auto child_it = executor_id_to_join_id_map.find(child.executor_id());
if (child_it != executor_id_to_join_id_map.end())
all_join_id.insert(all_join_id.end(), child_it->second.begin(), child_it->second.end());
});
}
assert(executor.has_executor_id());
if (executor.tp() == tipb::ExecType::TypeJoin)
all_join_id.push_back(executor.executor_id());
if (!all_join_id.empty())
executor_id_to_join_id_map[executor.executor_id()] = all_join_id;
});
}

std::unordered_map<String, std::vector<String>> & DAGContext::getExecutorIdToJoinIdMap()
Expand Down
7 changes: 4 additions & 3 deletions dbms/src/Flash/Coprocessor/DAGDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,17 @@ try
writer->Write(response);
}

auto streaming_writer = std::make_shared<StreamWriter>(writer);
const auto & settings = context.getSettingsRef();
auto streaming_writer = std::make_shared<StreamWriter>(writer, settings.max_threads);
TiDB::TiDBCollators collators;

std::unique_ptr<DAGResponseWriter> response_writer = std::make_unique<StreamingDAGResponseWriter<StreamWriterPtr, false>>(
streaming_writer,
std::vector<Int64>(),
collators,
tipb::ExchangeType::PassThrough,
context.getSettingsRef().dag_records_per_chunk,
context.getSettingsRef().batch_send_min_limit,
settings.dag_records_per_chunk,
settings.batch_send_min_limit,
true,
dag_context,
/*fine_grained_shuffle_stream_count=*/0,
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <Storages/Transaction/LockException.h>
#include <Storages/Transaction/TMTContext.h>
#include <TiDB/Schema/SchemaSyncer.h>
#include <common/logger_useful.h>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
Expand Down
143 changes: 143 additions & 0 deletions dbms/src/Flash/Coprocessor/StreamWriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#include <Flash/Coprocessor/StreamWriter.h>
#include <common/logger_useful.h>

namespace DB
{
StreamWriter::StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_, UInt64 input_streams_num)
: writer(writer_)
, connected(true)
, finished(false)
, send_queue(std::max(5, input_streams_num * 5))
, thread_manager(newThreadManager())
, log(Logger::get("StreamWriter"))
{
thread_manager->schedule(true, "StreamWriter", [this] {
sendJob();
});
}

StreamWriter::~StreamWriter()
{
try
{
{
std::unique_lock lock(write_mutex);
if (finished)
{
LOG_FMT_TRACE(log, "already finished!");
return;
}

// make sure to finish the stream writer after it is connected
waitUntilConnectedOrFinished(lock);
finishSendQueue();
}
LOG_FMT_TRACE(log, "waiting consumer finish!");
waitForConsumerFinish(/*allow_throw=*/false);
LOG_FMT_TRACE(log, "waiting child thread finished!");
thread_manager->wait();
}
catch (...)
{
tryLogCurrentException(log, "Error in destructor function of StreamWriter");
}
}

void StreamWriter::write(tipb::SelectResponse & response, [[maybe_unused]] uint16_t id)
{
auto rsp = std::make_shared<::coprocessor::BatchResponse>();
if (!response.SerializeToString(rsp->mutable_data()))
throw Exception(fmt::format("Fail to serialize response, response size: {}", response.ByteSizeLong()));

{
std::unique_lock lock(write_mutex);
waitUntilConnectedOrFinished(lock);
if (finished)
throw Exception("write to StreamWriter which is already closed, " + consumer_state.getError());

if (send_queue.push(rsp))
{
connection_profile_info.bytes += rsp->ByteSizeLong();
connection_profile_info.packets += 1;
return;
}
}
// push failed, wait consumer for the final state
waitForConsumerFinish(/*allow_throw=*/true);
}

void StreamWriter::writeDone()
{
LOG_FMT_TRACE(log, "ready to finish");
{
std::unique_lock lk(write_mutex);
if (finished)
throw Exception("write to StreamWriter which is already closed, " + consumer_state.getError());
waitUntilConnectedOrFinished(lk);
finishSendQueue();
}
waitForConsumerFinish(/*allow_throw=*/true);
}

void StreamWriter::sendJob()
{
String err_msg;
try
{
BatchResponsePtr rsp;
while (send_queue.pop(rsp))
{
if (!writer->Write(*rsp))
{
err_msg = "grpc writes failed.";
break;
}
}
}
catch (Exception & e)
{
err_msg = e.message();
}
catch (std::exception & e)
{
err_msg = e.what();
}
catch (...)
{
err_msg = fmt::format("fatal error in {}", __PRETTY_FUNCTION__);
}
if (!err_msg.empty())
LOG_ERROR(log, err_msg);
consumerFinish(err_msg);
}

void StreamWriter::waitForConsumerFinish(bool allow_throw)
{
LOG_FMT_TRACE(log, "start wait for consumer finish!");
String err_msg = consumer_state.getError();
if (allow_throw && !err_msg.empty())
throw Exception("Consumer exits unexpected, " + err_msg);
LOG_FMT_TRACE(log, "end wait for consumer finish!");
}

void StreamWriter::consumerFinish(const String & err_msg)
{
LOG_FMT_TRACE(log, "calling consumer finish");
send_queue.finish();
{
std::unique_lock lock(write_mutex);
if (finished && consumer_state.errHasSet())
return;
finished = true;
consumer_state.setError(err_msg);
cv_for_finished.notify_all();
}
}

void StreamWriter::waitUntilConnectedOrFinished(std::unique_lock<std::mutex> & lk)
{
LOG_FMT_TRACE(log, "start waitUntilConnectedOrFinished");
cv_for_finished.wait(lk, [&] { return connected || finished; });
LOG_FMT_TRACE(log, "end waitUntilConnectedOrFinished");
}
} // namespace DB
108 changes: 90 additions & 18 deletions dbms/src/Flash/Coprocessor/StreamWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
#pragma once

#include <Common/Exception.h>
#include <Common/Logger.h>
#include <Common/MPMCQueue.h>
#include <Common/ThreadManager.h>
#include <Common/nocopyable.h>
#include <Flash/Statistics/ConnectionProfileInfo.h>

#include <condition_variable>
#include <exception>
#include <future>
#include <mutex>

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
Expand All @@ -25,7 +36,6 @@
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#include <mutex>

namespace mpp
{
Expand All @@ -36,31 +46,93 @@ namespace DB
{
struct StreamWriter
{
::grpc::ServerWriter<::coprocessor::BatchResponse> * writer;
std::mutex write_mutex;
explicit StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_, UInt64 input_streams_num);

explicit StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_)
: writer(writer_)
{}
void write(mpp::MPPDataPacket &)
~StreamWriter();

static void write(mpp::MPPDataPacket &)
{
throw Exception("StreamWriter::write(mpp::MPPDataPacket &) do not support writing MPPDataPacket!");
}
void write(mpp::MPPDataPacket &, [[maybe_unused]] uint16_t)

static void write(mpp::MPPDataPacket &, uint16_t)
{
throw Exception("StreamWriter::write(mpp::MPPDataPacket &, [[maybe_unused]] uint16_t) do not support writing MPPDataPacket!");
throw Exception("StreamWriter::write(mpp::MPPDataPacket &, uint16_t) do not support writing MPPDataPacket!");
}
void write(tipb::SelectResponse & response, [[maybe_unused]] uint16_t id = 0)

void write(tipb::SelectResponse & response, uint16_t id = 0);

void writeDone();

// a helper function
static uint16_t getPartitionNum() { return 0; }

DISALLOW_COPY_AND_MOVE(StreamWriter);

private:
// work as a background task to keep sending packets until done.
void sendJob();

void waitForConsumerFinish(bool allow_throw);

void consumerFinish(const String & err_msg);

void finishSendQueue()
{
::coprocessor::BatchResponse resp;
if (!response.SerializeToString(resp.mutable_data()))
throw Exception("Fail to serialize response, response size: " + std::to_string(response.ByteSizeLong()));
std::lock_guard lk(write_mutex);
if (!writer->Write(resp))
throw Exception("Failed to write resp");
send_queue.finish();
}
// a helper function
uint16_t getPartitionNum() { return 0; }

void waitUntilConnectedOrFinished(std::unique_lock<std::mutex> & lk);

private:
std::mutex write_mutex;
::grpc::ServerWriter<::coprocessor::BatchResponse> * writer;

std::condition_variable cv_for_finished;
bool connected;
bool finished;
using BatchResponsePtr = std::shared_ptr<::coprocessor::BatchResponse>;
MPMCQueue<BatchResponsePtr> send_queue;

std::shared_ptr<ThreadManager> thread_manager;

/// Consumer can be sendLoop or local receiver.
class ConsumerState
{
public:
ConsumerState()
: future(promise.get_future())
{
}

// before finished, must be called without protection of mu
String getError()
{
future.wait();
return future.get();
}

void setError(const String & err_msg)
{
promise.set_value(err_msg);
err_has_set = true;
}

bool errHasSet() const
{
return err_has_set.load();
}

private:
std::promise<String> promise;
std::shared_future<String> future;
std::atomic<bool> err_has_set{false};
};
ConsumerState consumer_state;

ConnectionProfileInfo connection_profile_info;

LoggerPtr log;
};

using StreamWriterPtr = std::shared_ptr<StreamWriter>;
Expand Down

0 comments on commit 08a7d13

Please sign in to comment.