This is an automated cherry-pick of pingcap#9393
Signed-off-by: ti-chi-bot <[email protected]>
JinheLin authored and ti-chi-bot committed Nov 29, 2024
1 parent e98915d commit 70c7aed
Showing 5 changed files with 288 additions and 1 deletion.
18 changes: 18 additions & 0 deletions dbms/src/Common/UniThreadPool.cpp
@@ -90,6 +90,13 @@ void ThreadPoolImpl<Thread>::setQueueSize(size_t value)
jobs.reserve(queue_size);
}

template <typename Thread>
size_t ThreadPoolImpl<Thread>::getQueueSize() const
{
std::lock_guard lock(mutex);
return queue_size;
}


template <typename Thread>
template <typename ReturnType>
@@ -199,6 +206,17 @@ void ThreadPoolImpl<Thread>::scheduleOrThrow(
}

template <typename Thread>
std::future<void> ThreadPoolImpl<Thread>::scheduleWithFuture(Job job, uint64_t wait_timeout_us)
{
auto task = std::make_shared<std::packaged_task<void()>>(std::move(job));
scheduleImpl<void>([task]() { (*task)(); }, /*priority*/ 0, wait_timeout_us);
return task->get_future();
}

template <typename Thread>
void ThreadPoolImpl<Thread>::wait()
{
{
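For readers unfamiliar with the new API, here is a minimal usage sketch (not part of the diff). Because the job is wrapped in a std::packaged_task, the returned std::future<void> both signals completion and rethrows any exception the job threw. The pool sizes, the example job, and the assumption that DB::ThreadPool is the ThreadPoolImpl<std::thread> alias from UniThreadPool.h are illustrative.

```cpp
// Minimal sketch of scheduleWithFuture; names and sizes are assumptions.
#include <Common/UniThreadPool.h>

#include <future>
#include <iostream>
#include <stdexcept>

int main()
{
    // Assumed constructor: (max_threads, max_free_threads, queue_size).
    DB::ThreadPool pool(/*max_threads*/ 4, /*max_free_threads*/ 2, /*queue_size*/ 8);

    // Wait up to 100000us for a queue slot; the job's exception is
    // captured by the packaged_task and surfaced through the future.
    std::future<void> f = pool.scheduleWithFuture(
        [] { throw std::runtime_error("boom"); },
        /*wait_timeout_us*/ 100000);

    try
    {
        f.get(); // Blocks until the job runs; rethrows the job's exception.
    }
    catch (const std::exception & e)
    {
        std::cout << e.what() << std::endl; // Prints "boom".
    }
    return 0;
}
```

Because each future owns its result, a caller can schedule many jobs, collect the futures, and surface the first failure only after all of them finish, which is exactly the pattern the new tests below exercise.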
10 changes: 10 additions & 0 deletions dbms/src/Common/UniThreadPool.h
@@ -74,15 +74,24 @@ class ThreadPoolImpl
void scheduleOrThrowOnError(Job job, ssize_t priority = 0);

/// Similar to scheduleOrThrowOnError(...). Wait for the specified amount of time and schedule the job, or return false.
/// If wait_microseconds is zero, the call does not wait at all.
bool trySchedule(Job job, ssize_t priority = 0, uint64_t wait_microseconds = 0) noexcept;

/// Similar to scheduleOrThrowOnError(...). Wait for the specified amount of time and schedule the job, or throw an exception.
/// If wait_microseconds is zero, the call does not wait at all.
void scheduleOrThrow(
Job job,
ssize_t priority = 0,
uint64_t wait_microseconds = 0,
bool propagate_opentelemetry_tracing_context = true);

/// Wraps the job in a std::packaged_task<void> and returns a std::future<void> that can be used
/// to check whether the task has finished or thrown an exception.
/// If wait_timeout_us is zero, the call does not wait at all.
std::future<void> scheduleWithFuture(Job job, uint64_t wait_timeout_us = 0);

/// Wait for all currently active jobs to be done.
/// You may call schedule and wait many times in arbitrary order.
/// If any thread has thrown an exception, the first exception will be rethrown from this method,
@@ -104,6 +113,7 @@ class ThreadPoolImpl
void setMaxFreeThreads(size_t value);
void setQueueSize(size_t value);
size_t getMaxThreads() const;
size_t getQueueSize() const;

std::unique_ptr<ThreadPoolWaitGroup<Thread>> waitGroup()
{
194 changes: 194 additions & 0 deletions dbms/src/IO/tests/gtest_io_thread.cpp
@@ -0,0 +1,194 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include <IO/IOThreadPools.h>
#include <gtest/gtest.h>

#include <exception>
#include <ext/scope_guard.h>
#include <future>
#include <random>

namespace DB::tests
{
namespace
{
using SPtr = std::shared_ptr<std::atomic<int>>;
using WPtr = std::weak_ptr<std::atomic<int>>;

void buildReadTasks(bool throw_in_build_read_tasks, bool throw_in_build_task, WPtr wp, SPtr invalid_count)
{
auto async_build_read_task = [=]() {
return BuildReadTaskPool::get().scheduleWithFuture([=]() {
std::random_device rd;
std::mt19937 gen(rd());
auto sleep_ms = gen() % 100 + 1; // 1~100
std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));

if (auto sp = wp.lock(); sp)
sp->fetch_add(1);
else
invalid_count->fetch_add(1);

if (throw_in_build_task)
throw Exception("From build_read_task");
});
};

IOPoolHelper::FutureContainer futures(Logger::get("buildReadTasks"));
for (int i = 0; i < 10; i++)
{
futures.add(async_build_read_task());
if (i >= 5 && throw_in_build_read_tasks)
throw Exception("From buildReadTasks");
}
futures.getAllResults();
}

void buildReadTasksForTables(
bool throw_in_build_read_tasks_for_tables,
bool throw_in_build_read_tasks,
bool throw_in_build_task,
WPtr wp,
SPtr invalid_count)
{
auto async_build_read_tasks_for_table = [=]() {
return BuildReadTaskForWNTablePool::get().scheduleWithFuture(
[=]() { buildReadTasks(throw_in_build_read_tasks, throw_in_build_task, wp, invalid_count); });
};

IOPoolHelper::FutureContainer futures(Logger::get("buildReadTasksForTables"));
for (int i = 0; i < 10; i++)
{
futures.add(async_build_read_tasks_for_table());
if (i >= 5 && throw_in_build_read_tasks_for_tables)
throw Exception("From buildReadTasksForTables");
}
futures.getAllResults();
}

void buildReadTasksForWNs(
bool throw_in_build_read_tasks_for_wns,
bool throw_in_build_read_tasks_for_tables,
bool throw_in_build_read_tasks,
bool throw_in_build_task)
{
auto log = Logger::get("buildReadTasksForWNs");
LOG_INFO(
log,
"throw_in_build_read_tasks_for_wns={}, "
"throw_in_build_read_tasks_for_tables={}, throw_in_build_read_tasks={}, throw_in_build_task={}",
throw_in_build_read_tasks_for_wns,
throw_in_build_read_tasks_for_tables,
throw_in_build_read_tasks,
throw_in_build_task);
auto sp = std::make_shared<std::atomic<int>>(0);
auto invalid_count = std::make_shared<std::atomic<int>>(0);
// Use weak_ptr to simulate capture by reference.
auto async_build_tasks_for_wn = [wp = WPtr{sp},
invalid_count,
throw_in_build_read_tasks_for_tables,
throw_in_build_read_tasks,
throw_in_build_task]() {
return BuildReadTaskForWNPool::get().scheduleWithFuture([=]() {
buildReadTasksForTables(
throw_in_build_read_tasks_for_tables,
throw_in_build_read_tasks,
throw_in_build_task,
wp,
invalid_count);
});
};

try
{
IOPoolHelper::FutureContainer futures(log);
SCOPE_EXIT({
if (!throw_in_build_read_tasks_for_wns && !throw_in_build_read_tasks_for_tables
&& !throw_in_build_read_tasks)
ASSERT_EQ(*sp, 10 * 10 * 10);
ASSERT_EQ(*invalid_count, 0);
});
for (int i = 0; i < 10; i++)
{
futures.add(async_build_tasks_for_wn());
if (i >= 5 && throw_in_build_read_tasks_for_wns)
throw Exception("From buildReadTasksForWNs");
}
futures.getAllResults();
}
catch (...)
{
tryLogCurrentException(log);
}
}
} // namespace

TEST(IOThreadPool, TaskChain)
{
constexpr std::array<bool, 2> arr{false, true};
for (auto a : arr)
for (auto b : arr)
for (auto c : arr)
for (auto d : arr)
buildReadTasksForWNs(a, b, c, d);
}

TEST(IOThreadPool, WaitTimeout)
{
auto & thread_pool = BuildReadTaskPool::get();
const auto queue_size = thread_pool.getQueueSize();
std::atomic<bool> stop_flag{false};
IOPoolHelper::FutureContainer futures(Logger::get());
auto loop_until_stop = [&]() {
while (!stop_flag)
std::this_thread::sleep_for(std::chrono::seconds(1));
};
for (size_t i = 0; i < queue_size; ++i)
{
auto f = thread_pool.scheduleWithFuture(loop_until_stop);
futures.add(std::move(f));
}
ASSERT_EQ(thread_pool.active(), queue_size);

auto try_result = thread_pool.trySchedule(loop_until_stop);
ASSERT_FALSE(try_result);

try
{
auto f = thread_pool.scheduleWithFuture(loop_until_stop);
futures.add(std::move(f));
FAIL() << "Should throw exception.";
}
catch (Exception & e)
{
ASSERT_TRUE(e.message().starts_with("Cannot schedule a task: no free thread (timeout=0)"));
}

try
{
auto f = thread_pool.scheduleWithFuture(loop_until_stop, 10000);
futures.add(std::move(f));
FAIL() << "Should throw exception.";
}
catch (Exception & e)
{
ASSERT_TRUE(e.message().starts_with("Cannot schedule a task: no free thread (timeout=10000)"));
}

stop_flag.store(true);
futures.getAllResults();
}
} // namespace DB::tests
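The tests above lean on IOPoolHelper::FutureContainer, whose definition is not part of this diff. A plausible minimal sketch of such a helper is shown below, under the assumption that it simply collects std::future<void> objects and surfaces the first failure only after every task has finished; the real TiFlash implementation may differ (for example, it also takes a logger).

```cpp
// Hypothetical sketch of a FutureContainer-like helper; illustrative only.
#include <exception>
#include <future>
#include <vector>

class FutureContainer
{
public:
    void add(std::future<void> && f) { futures.push_back(std::move(f)); }

    // Drain every future, remembering the first exception, and rethrow it
    // only after all tasks are done, so no task is left running with
    // references to destroyed locals.
    void getAllResults()
    {
        std::exception_ptr first_exception;
        for (auto & f : futures)
        {
            try
            {
                f.get();
            }
            catch (...)
            {
                if (!first_exception)
                    first_exception = std::current_exception();
            }
        }
        if (first_exception)
            std::rethrow_exception(first_exception);
    }

private:
    std::vector<std::future<void>> futures;
};
```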
7 changes: 7 additions & 0 deletions dbms/src/Storages/StorageDisaggregated.h
@@ -173,6 +173,13 @@ class StorageDisaggregated : public IStorage
DAGExpressionAnalyzer & analyzer);
tipb::Executor buildTableScanTiPB();

size_t getBuildTaskRPCTimeout() const;
size_t getBuildTaskIOThreadPoolTimeout() const;

private:
Context & context;
const TiDBTableScan & table_scan;
LoggerPtr log;
60 changes: 59 additions & 1 deletion dbms/src/Storages/StorageDisaggregatedRemote.cpp
@@ -187,9 +187,16 @@ DM::Remote::RNReadTaskPtr StorageDisaggregated::buildReadTask(
IOPoolHelper::FutureContainer futures(log, batch_cop_tasks.size());
for (const auto & cop_task : batch_cop_tasks)
{
auto f = BuildReadTaskForWNPool::get().scheduleWithFuture(
[&] { buildReadTaskForWriteNode(db_context, scan_context, cop_task, output_lock, output_seg_tasks); },
getBuildTaskIOThreadPoolTimeout());
futures.add(std::move(f));
}

// Let's wait for all tasks to finish. Otherwise local variable references will be invalid.
@@ -221,7 +228,7 @@ void StorageDisaggregated::buildReadTaskForWriteNode(
pingcap::kv::RpcCall<pingcap::kv::RPC_NAME(EstablishDisaggTask)> rpc(cluster->rpc_client, req->address());
disaggregated::EstablishDisaggTaskResponse resp;
grpc::ClientContext client_context;
rpc.setClientContext(client_context, db_context.getSettingsRef().disagg_build_task_timeout);
rpc.setClientContext(client_context, getBuildTaskRPCTimeout());
auto status = rpc.call(&client_context, *req, &resp);
if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED)
throw Exception(
@@ -339,6 +346,7 @@ void StorageDisaggregated::buildReadTaskForWriteNode(
IOPoolHelper::FutureContainer futures(log, resp.tables().size());
for (const auto & serialized_physical_table : resp.tables())
{
auto f = BuildReadTaskForWNTablePool::get().scheduleWithFuture(
[&] {
buildReadTaskForWriteNodeTable(
db_context,
scan_context,
snapshot_id,
resp.store_id(),
req->address(),
serialized_physical_table,
output_lock,
output_seg_tasks);
},
getBuildTaskIOThreadPoolTimeout());
futures.add(std::move(f));
}
futures.getAllResults();
}
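One detail worth calling out in the hunks above: every scheduled lambda captures its surroundings by reference (`[&]`), so all the futures must be drained before the enclosing function unwinds, even on the exception path; otherwise queued tasks would touch destroyed locals. A hedged sketch of the hazard and the fix, with illustrative names only (`Pool` stands in for the IO thread pools used above):

```cpp
#include <future>
#include <mutex>
#include <vector>

// Illustrative only: any pool exposing scheduleWithFuture() fits here.
template <typename Pool>
void scheduleAll(Pool & pool, const std::vector<int> & inputs, std::vector<int> & out)
{
    std::mutex mutex; // Protects `out`, like output_lock above.
    std::vector<std::future<void>> futures;
    for (int x : inputs)
        futures.push_back(pool.scheduleWithFuture([&, x] {
            std::lock_guard lock(mutex);
            out.push_back(x * 2);
        }));
    // Drain every future before `mutex` (and the caller's locals) are
    // destroyed; FutureContainer::getAllResults() plays this role above.
    for (auto & f : futures)
        f.get();
}
```

Note that this bare loop of f.get() calls would stop draining at the first exception; collecting all results before rethrowing, as getAllResults() does in the tests, avoids leaving tasks running with dangling references.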
@@ -373,6 +397,7 @@ void StorageDisaggregated::buildReadTaskForWriteNodeTable(

auto table_tracing_logger = log->getChild(
fmt::format("store_id={} keyspace={} table_id={}", store_id, table.keyspace_id(), table.table_id()));

IOPoolHelper::FutureContainer futures(log, table.segments().size());
for (const auto & remote_seg : table.segments())
{
auto f = BuildReadTaskPool::get().scheduleWithFuture(
[&]() {
auto seg_read_task = std::make_shared<DM::SegmentReadTask>(
table_tracing_logger,
db_context,
scan_context,
remote_seg,
snapshot_id,
store_id,
store_address,
table.keyspace_id(),
table.table_id());
std::lock_guard lock(output_lock);
output_seg_tasks.push_back(seg_read_task);
},
getBuildTaskIOThreadPoolTimeout());
futures.add(std::move(f));
}

futures.getAllResults();
@@ -612,4 +660,14 @@ void StorageDisaggregated::buildRemoteSegmentSourceOps(
group_builder.getCurProfileInfos());
}

size_t StorageDisaggregated::getBuildTaskRPCTimeout() const
{
return context.getSettingsRef().disagg_build_task_timeout;
}

size_t StorageDisaggregated::getBuildTaskIOThreadPoolTimeout() const
{
return context.getSettingsRef().disagg_build_task_timeout * 1000000;
}
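The two getters differ only in units: assuming disagg_build_task_timeout is expressed in seconds (which the x1000000 factor suggests), the RPC deadline uses it directly while the thread-pool wait is the same duration converted to microseconds for scheduleWithFuture's wait_timeout_us parameter. A quick arithmetic check with an illustrative value:

```cpp
#include <cstddef>

// Illustrative unit check, assuming the setting is in seconds.
constexpr size_t disagg_build_task_timeout = 30;                        // e.g. 30 s (example value)
constexpr size_t rpc_timeout = disagg_build_task_timeout;               // seconds, for setClientContext
constexpr size_t pool_timeout_us = disagg_build_task_timeout * 1000000; // microseconds, for the IO pools
static_assert(pool_timeout_us == 30000000, "30 s == 30,000,000 us");
```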

} // namespace DB
