Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] coprocessor: separate thread for encoding data for batch cop #5349

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions dbms/src/Flash/Coprocessor/DAGContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Flash/Mpp/ExchangeReceiver.h>
#include <Flash/Statistics/traverseExecutors.h>
#include <Storages/Transaction/TMTContext.h>
#include <common/logger_useful.h>

namespace DB
{
Expand Down Expand Up @@ -86,25 +87,25 @@ void DAGContext::initExecutorIdToJoinIdMap()
{
// only mpp task has join executor
// for mpp, all executor has executor id.
if (isMPPTask())
{
executor_id_to_join_id_map.clear();
traverseExecutorsReverse(dag_request, [&](const tipb::Executor & executor) {
std::vector<String> all_join_id;
// for mpp, dag_request.has_root_executor() == true, can call `getChildren` directly.
getChildren(executor).forEach([&](const tipb::Executor & child) {
assert(child.has_executor_id());
auto child_it = executor_id_to_join_id_map.find(child.executor_id());
if (child_it != executor_id_to_join_id_map.end())
all_join_id.insert(all_join_id.end(), child_it->second.begin(), child_it->second.end());
});
assert(executor.has_executor_id());
if (executor.tp() == tipb::ExecType::TypeJoin)
all_join_id.push_back(executor.executor_id());
if (!all_join_id.empty())
executor_id_to_join_id_map[executor.executor_id()] = all_join_id;
if (!isMPPTask())
return;

executor_id_to_join_id_map.clear();
traverseExecutorsReverse(dag_request, [&](const tipb::Executor & executor) {
std::vector<String> all_join_id;
// for mpp, dag_request.has_root_executor() == true, can call `getChildren` directly.
getChildren(executor).forEach([&](const tipb::Executor & child) {
assert(child.has_executor_id());
auto child_it = executor_id_to_join_id_map.find(child.executor_id());
if (child_it != executor_id_to_join_id_map.end())
all_join_id.insert(all_join_id.end(), child_it->second.begin(), child_it->second.end());
});
}
assert(executor.has_executor_id());
if (executor.tp() == tipb::ExecType::TypeJoin)
all_join_id.push_back(executor.executor_id());
if (!all_join_id.empty())
executor_id_to_join_id_map[executor.executor_id()] = all_join_id;
});
}

std::unordered_map<String, std::vector<String>> & DAGContext::getExecutorIdToJoinIdMap()
Expand Down
7 changes: 4 additions & 3 deletions dbms/src/Flash/Coprocessor/DAGDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,17 @@ try
writer->Write(response);
}

auto streaming_writer = std::make_shared<StreamWriter>(writer);
const auto & settings = context.getSettingsRef();
auto streaming_writer = std::make_shared<StreamWriter>(writer, settings.max_threads);
TiDB::TiDBCollators collators;

std::unique_ptr<DAGResponseWriter> response_writer = std::make_unique<StreamingDAGResponseWriter<StreamWriterPtr, false>>(
streaming_writer,
std::vector<Int64>(),
collators,
tipb::ExchangeType::PassThrough,
context.getSettingsRef().dag_records_per_chunk,
context.getSettingsRef().batch_send_min_limit,
settings.dag_records_per_chunk,
settings.batch_send_min_limit,
true,
dag_context,
/*fine_grained_shuffle_stream_count=*/0,
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Coprocessor/DAGStorageInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <Storages/Transaction/LockException.h>
#include <Storages/Transaction/TMTContext.h>
#include <TiDB/Schema/SchemaSyncer.h>
#include <common/logger_useful.h>

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
Expand Down
158 changes: 158 additions & 0 deletions dbms/src/Flash/Coprocessor/StreamWriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright 2022 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Flash/Coprocessor/StreamWriter.h>
#include <common/logger_useful.h>

namespace DB
{
/// Creates a writer that forwards BatchResponse messages to `writer_` from a background job.
///
/// Callers of write() only enqueue responses; the job scheduled here (sendJob) pops them
/// from the bounded send queue and writes them to gRPC.
///
/// @param writer_ gRPC server writer that receives the queued responses; kept as a raw pointer.
/// @param input_streams_num sizes the send queue: its capacity is max(5, input_streams_num * 5).
StreamWriter::StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_, UInt64 input_streams_num)
    : writer(writer_)
    , connected(true)
    , finished(false)
    // std::max deduces a single type from both arguments; the literal 5 is an int while
    // input_streams_num is a UInt64, so the template argument has to be spelled out.
    , send_queue(std::max<UInt64>(5, input_streams_num * 5))
    , thread_manager(newThreadManager())
    , log(Logger::get("StreamWriter"))
{
    // The job captures `this`, so it runs against the members initialised above.
    thread_manager->schedule(true, "StreamWriter", [this] {
        sendJob();
    });
}

/// Shuts the writer down; no exception is allowed to escape.
///
/// When the consumer has already finished, nothing else is done. Otherwise the send queue
/// is closed, then the destructor waits for the consumer's final state and for the
/// scheduled job to end.
StreamWriter::~StreamWriter()
{
    try
    {
        // Returns false when the writer was already finished, true once the queue was closed here.
        const bool queue_closed_here = [this] {
            std::unique_lock guard(mu);
            if (finished)
            {
                LOG_FMT_TRACE(log, "already finished!");
                return false;
            }

            // the stream writer is only finished after it is connected
            waitUntilConnectedOrFinished(guard);
            finishSendQueue();
            return true;
        }();
        if (!queue_closed_here)
            return;

        LOG_FMT_TRACE(log, "waiting consumer finish!");
        // errors reported by the consumer are deliberately not rethrown here
        waitForConsumerFinish(/*allow_throw=*/false);
        LOG_FMT_TRACE(log, "waiting child thread finished!");
        thread_manager->wait();
    }
    catch (...)
    {
        tryLogCurrentException(log, "Error in destructor function of StreamWriter");
    }
}

/// Serializes `response` into a BatchResponse and enqueues it for the background sender.
///
/// @param response the select response to serialize and send.
/// @param id unused.
/// @throws Exception if serialization fails, if the writer is already finished, or if the
///         push is rejected and the consumer reported an error.
void StreamWriter::write(tipb::SelectResponse & response, [[maybe_unused]] uint16_t id)
{
    auto packet = std::make_shared<::coprocessor::BatchResponse>();
    const bool serialized = response.SerializeToString(packet->mutable_data());
    if (!serialized)
        throw Exception(fmt::format("Fail to serialize response, response size: {}", response.ByteSizeLong()));

    bool enqueued = false;
    {
        std::unique_lock guard(mu);
        waitUntilConnectedOrFinished(guard);
        if (finished)
            throw Exception("write to StreamWriter which is already closed, " + consumer_state.getError());

        if (send_queue.push(packet))
        {
            // account for the packet only once it has been accepted by the queue
            connection_profile_info.packets += 1;
            connection_profile_info.bytes += packet->ByteSizeLong();
            enqueued = true;
        }
    }
    if (enqueued)
        return;

    // the push was rejected: wait (without holding mu) for the consumer's final state
    waitForConsumerFinish(/*allow_throw=*/true);
}

/// Closes the send queue and waits for the consumer to report its final state.
///
/// @throws Exception if the writer is already finished, or if the consumer reports an error.
void StreamWriter::writeDone()
{
    LOG_FMT_TRACE(log, "ready to finish");

    std::unique_lock guard(mu);
    if (finished)
        throw Exception("write to StreamWriter which is already closed, " + consumer_state.getError());
    waitUntilConnectedOrFinished(guard);
    finishSendQueue();
    // release mu before waiting on the consumer; if anything above throws, the guard releases it
    guard.unlock();

    waitForConsumerFinish(/*allow_throw=*/true);
}

void StreamWriter::sendJob()
{
String err_msg;
try
{
BatchResponsePtr rsp;
while (send_queue.pop(rsp))
{
if (!writer->Write(*rsp))
{
err_msg = "grpc writes failed.";
break;
}
}
}
catch (Exception & e)
{
err_msg = e.message();
}
catch (std::exception & e)
{
err_msg = e.what();
}
catch (...)
{
err_msg = fmt::format("fatal error in {}", __PRETTY_FUNCTION__);
}
if (!err_msg.empty())
LOG_ERROR(log, err_msg);
consumerFinish(err_msg);
}

void StreamWriter::waitForConsumerFinish(bool allow_throw)
{
LOG_FMT_TRACE(log, "start wait for consumer finish!");
String err_msg = consumer_state.getError(); // may blocking
if (allow_throw && !err_msg.empty())
throw Exception("Consumer exits unexpected, " + err_msg);
LOG_FMT_TRACE(log, "end wait for consumer finish!");
}

/// Records the consumer's final state: closes the send queue, marks the writer finished,
/// stores `err_msg` (empty means no error) and wakes every thread waiting on cv_for_finished.
/// Does nothing more if the state has already been recorded.
void StreamWriter::consumerFinish(const String & err_msg)
{
    LOG_FMT_TRACE(log, "calling consumer finish");
    // The queue must be finished before taking mu: write() pushes while holding mu, so
    // finishing the queue inside the critical section could deadlock with it.
    send_queue.finish();

    std::lock_guard guard(mu);
    const bool already_recorded = finished && consumer_state.errHasSet();
    if (!already_recorded)
    {
        finished = true;
        consumer_state.setError(err_msg);
        cv_for_finished.notify_all();
    }
}

/// Blocks on cv_for_finished until the writer is either connected or finished.
///
/// @param lk a lock on mu held by the caller; it is released while waiting and held again on return.
void StreamWriter::waitUntilConnectedOrFinished(std::unique_lock<std::mutex> & lk)
{
    LOG_FMT_TRACE(log, "start waitUntilConnectedOrFinished");
    // re-check after every wake-up so that spurious wake-ups are harmless
    while (!connected && !finished)
        cv_for_finished.wait(lk);
    LOG_FMT_TRACE(log, "end waitUntilConnectedOrFinished");
}
} // namespace DB
108 changes: 90 additions & 18 deletions dbms/src/Flash/Coprocessor/StreamWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
#pragma once

#include <Common/Exception.h>
#include <Common/Logger.h>
#include <Common/MPMCQueue.h>
#include <Common/ThreadManager.h>
#include <Common/nocopyable.h>
#include <Flash/Statistics/ConnectionProfileInfo.h>

#include <condition_variable>
#include <exception>
#include <future>
#include <mutex>

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
Expand All @@ -25,7 +36,6 @@
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#include <mutex>

namespace mpp
{
Expand All @@ -36,31 +46,93 @@ namespace DB
{
struct StreamWriter
{
::grpc::ServerWriter<::coprocessor::BatchResponse> * writer;
std::mutex write_mutex;
explicit StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_, UInt64 input_streams_num);

~StreamWriter();

explicit StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_)
: writer(writer_)
{}
void write(mpp::MPPDataPacket &)
static void write(mpp::MPPDataPacket &)
{
throw Exception("StreamWriter::write(mpp::MPPDataPacket &) do not support writing MPPDataPacket!");
}
void write(mpp::MPPDataPacket &, [[maybe_unused]] uint16_t)

static void write(mpp::MPPDataPacket &, uint16_t)
{
throw Exception("StreamWriter::write(mpp::MPPDataPacket &, [[maybe_unused]] uint16_t) do not support writing MPPDataPacket!");
throw Exception("StreamWriter::write(mpp::MPPDataPacket &, uint16_t) do not support writing MPPDataPacket!");
}
void write(tipb::SelectResponse & response, [[maybe_unused]] uint16_t id = 0)

void write(tipb::SelectResponse & response, uint16_t id = 0);

void writeDone();

// a helper function
static uint16_t getPartitionNum() { return 0; }

DISALLOW_COPY_AND_MOVE(StreamWriter);

private:
// work as a background task to keep sending packets until done.
void sendJob();

void waitForConsumerFinish(bool allow_throw);

void consumerFinish(const String & err_msg);

void finishSendQueue()
{
::coprocessor::BatchResponse resp;
if (!response.SerializeToString(resp.mutable_data()))
throw Exception("Fail to serialize response, response size: " + std::to_string(response.ByteSizeLong()));
std::lock_guard lk(write_mutex);
if (!writer->Write(resp))
throw Exception("Failed to write resp");
send_queue.finish();
}
// a helper function
uint16_t getPartitionNum() { return 0; }

void waitUntilConnectedOrFinished(std::unique_lock<std::mutex> & lk);

private:
std::mutex mu;
::grpc::ServerWriter<::coprocessor::BatchResponse> * writer;

std::condition_variable cv_for_finished;
bool connected;
bool finished;
using BatchResponsePtr = std::shared_ptr<::coprocessor::BatchResponse>;
MPMCQueue<BatchResponsePtr> send_queue;

std::shared_ptr<ThreadManager> thread_manager;

/// Consumer can be sendLoop or local receiver.
class ConsumerState
{
public:
ConsumerState()
: future(promise.get_future())
{
}

// before finished, must be called without protection of mu
String getError()
{
future.wait();
return future.get();
}

void setError(const String & err_msg)
{
promise.set_value(err_msg);
err_has_set = true;
}

bool errHasSet() const
{
return err_has_set.load();
}

private:
std::promise<String> promise;
std::shared_future<String> future;
std::atomic<bool> err_has_set{false};
};
ConsumerState consumer_state;

ConnectionProfileInfo connection_profile_info;

LoggerPtr log;
};

using StreamWriterPtr = std::shared_ptr<StreamWriter>;
Expand Down