Disagg: Set the waiting time for thread pool. (#9393)
close #9392

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
JinheLin and ti-chi-bot[bot] authored Sep 2, 2024
1 parent 0a80ae0 commit 738aade
Showing 5 changed files with 88 additions and 17 deletions.
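In short: jobs scheduled on the disaggregated read path's IO thread pools now carry an explicit wait timeout derived from the disagg_build_task_timeout setting, instead of the default of not waiting at all when the pool is saturated. Two new StorageDisaggregated helpers centralize the unit handling (the raw setting for the EstablishDisaggTask RPC, a microsecond conversion for the thread pools), and a new test covers the pool's behavior when its queue is full.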
9 changes: 8 additions & 1 deletion dbms/src/Common/UniThreadPool.cpp
@@ -90,6 +90,13 @@ void ThreadPoolImpl<Thread>::setQueueSize(size_t value)
jobs.reserve(queue_size);
}

template <typename Thread>
size_t ThreadPoolImpl<Thread>::getQueueSize() const
{
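// queue_size can be resized concurrently via setQueueSize(), so read it under the pool mutex.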
std::lock_guard lock(mutex);
return queue_size;
}


template <typename Thread>
template <typename ReturnType>
@@ -202,7 +209,7 @@ template <typename Thread>
std::future<void> ThreadPoolImpl<Thread>::scheduleWithFuture(Job job, uint64_t wait_timeout_us)
{
auto task = std::make_shared<std::packaged_task<void()>>(std::move(job));
- scheduleOrThrow([task]() { (*task)(); }, 0, wait_timeout_us);
+ scheduleImpl<void>([task]() { (*task)(); }, /*priority*/ 0, wait_timeout_us);
return task->get_future();
}

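The contract of scheduleWithFuture is easiest to see in isolation: the job is wrapped in a std::packaged_task, so the returned future reports completion and rethrows the job's exception. Here is a minimal, self-contained sketch of that pattern; schedule_with_future and the detached std::thread are illustrative stand-ins, not TiFlash code.

#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <thread>

// Hypothetical stand-in for ThreadPoolImpl::scheduleWithFuture: wrap the job in a
// std::packaged_task so the caller can observe completion, or the job's exception,
// through the returned future.
std::future<void> schedule_with_future(std::function<void()> job)
{
    auto task = std::make_shared<std::packaged_task<void()>>(std::move(job));
    std::future<void> f = task->get_future();
    // The real code hands a [task] lambda to the pool with a wait timeout;
    // a detached thread keeps this sketch self-contained.
    std::thread([task] { (*task)(); }).detach();
    return f;
}

int main()
{
    auto f = schedule_with_future([] { throw std::runtime_error("boom"); });
    try
    {
        f.get(); // rethrows the job's exception
    }
    catch (const std::exception & e)
    {
        std::cout << "job failed: " << e.what() << '\n';
    }
    return 0;
}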
6 changes: 5 additions & 1 deletion dbms/src/Common/UniThreadPool.h
@@ -74,16 +74,19 @@ class ThreadPoolImpl
void scheduleOrThrowOnError(Job job, ssize_t priority = 0);

/// Similar to scheduleOrThrowOnError(...). Wait for specified amount of time and schedule a job or return false.
/// If wait_microseconds is zero, the call does not wait at all.
bool trySchedule(Job job, ssize_t priority = 0, uint64_t wait_microseconds = 0) noexcept;

/// Similar to scheduleOrThrowOnError(...). Wait for specified amount of time and schedule a job or throw an exception.
/// If wait_microseconds is zero, the call does not wait at all.
void scheduleOrThrow(
Job job,
ssize_t priority = 0,
uint64_t wait_microseconds = 0,
bool propagate_opentelemetry_tracing_context = true);

- /// Wrap job with std::packaged_task<void> and returns a std::future<void> object to check result of the job.
+ /// Wraps job in a std::packaged_task<void> and returns a std::future<void> that can be used to check whether the task has finished or thrown an exception.
+ /// If wait_microseconds is zero, the call does not wait at all.
std::future<void> scheduleWithFuture(Job job, uint64_t wait_timeout_us = 0);

/// Wait for all currently active jobs to be done.
@@ -107,6 +110,7 @@ class ThreadPoolImpl
void setMaxFreeThreads(size_t value);
void setQueueSize(size_t value);
size_t getMaxThreads() const;
size_t getQueueSize() const;

std::unique_ptr<ThreadPoolWaitGroup<Thread>> waitGroup()
{
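From the caller's side, the new wait_timeout_us parameter behaves as the tests below exercise it. A hedged usage sketch, assuming a pool instance my_pool obtained elsewhere and an arbitrary 500 ms budget:

try
{
    // Wait at most 500ms for a free worker before giving up.
    auto f = my_pool.scheduleWithFuture([] { /* IO work */ }, /*wait_timeout_us=*/500000);
    f.get(); // rethrows any exception thrown by the job itself
}
catch (const DB::Exception &)
{
    // Thrown by the pool when no thread frees up in time, e.g.
    // "Cannot schedule a task: no free thread (timeout=500000)".
}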
45 changes: 45 additions & 0 deletions dbms/src/IO/tests/gtest_io_thread.cpp
@@ -146,4 +146,49 @@ TEST(IOThreadPool, TaskChain)
buildReadTasksForWNs(a, b, c, d);
}

TEST(IOThreadPool, WaitTimeout)
{
auto & thread_pool = BuildReadTaskPool::get();
const auto queue_size = thread_pool.getQueueSize();
std::atomic<bool> stop_flag{false};
IOPoolHelper::FutureContainer futures(Logger::get());
auto loop_until_stop = [&]() {
while (!stop_flag)
std::this_thread::sleep_for(std::chrono::seconds(1));
};
for (size_t i = 0; i < queue_size; ++i)
{
auto f = thread_pool.scheduleWithFuture(loop_until_stop);
futures.add(std::move(f));
}
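// All queue slots are now occupied by blocking jobs, so further scheduling must wait for a slot.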
ASSERT_EQ(thread_pool.active(), queue_size);

auto try_result = thread_pool.trySchedule(loop_until_stop);
ASSERT_FALSE(try_result);

try
{
auto f = thread_pool.scheduleWithFuture(loop_until_stop);
futures.add(std::move(f));
FAIL() << "Should throw exception.";
}
catch (Exception & e)
{
ASSERT_TRUE(e.message().starts_with("Cannot schedule a task: no free thread (timeout=0)"));
}

try
{
auto f = thread_pool.scheduleWithFuture(loop_until_stop, 10000);
futures.add(std::move(f));
FAIL() << "Should throw exception.";
}
catch (Exception & e)
{
ASSERT_TRUE(e.message().starts_with("Cannot schedule a task: no free thread (timeout=10000)"));
}

stop_flag.store(true);
futures.getAllResults();
}
} // namespace DB::tests
3 changes: 3 additions & 0 deletions dbms/src/Storages/StorageDisaggregated.h
@@ -143,6 +143,9 @@ class StorageDisaggregated : public IStorage
DAGExpressionAnalyzer & analyzer);
tipb::Executor buildTableScanTiPB();

size_t getBuildTaskRPCTimeout() const;
size_t getBuildTaskIOThreadPoolTimeout() const;

private:
Context & context;
const TiDBTableScan & table_scan;
42 changes: 27 additions & 15 deletions dbms/src/Storages/StorageDisaggregatedRemote.cpp
@@ -217,7 +217,8 @@ DM::SegmentReadTasks StorageDisaggregated::buildReadTask(
for (const auto & cop_task : batch_cop_tasks)
{
auto f = BuildReadTaskForWNPool::get().scheduleWithFuture(
- [&] { buildReadTaskForWriteNode(db_context, scan_context, cop_task, output_lock, output_seg_tasks); });
+ [&] { buildReadTaskForWriteNode(db_context, scan_context, cop_task, output_lock, output_seg_tasks); },
+ getBuildTaskIOThreadPoolTimeout());
futures.add(std::move(f));
}
futures.getAllResults();
@@ -246,7 +247,7 @@ void StorageDisaggregated::buildReadTaskForWriteNode(
pingcap::kv::RpcCall<pingcap::kv::RPC_NAME(EstablishDisaggTask)> rpc(cluster->rpc_client, req->address());
disaggregated::EstablishDisaggTaskResponse resp;
grpc::ClientContext client_context;
- rpc.setClientContext(client_context, db_context.getSettingsRef().disagg_build_task_timeout);
+ rpc.setClientContext(client_context, getBuildTaskRPCTimeout());
auto status = rpc.call(&client_context, *req, &resp);
if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED)
throw Exception(
@@ -364,17 +365,19 @@ void StorageDisaggregated::buildReadTaskForWriteNode(
IOPoolHelper::FutureContainer futures(log, resp.tables().size());
for (const auto & serialized_physical_table : resp.tables())
{
- auto f = BuildReadTaskForWNTablePool::get().scheduleWithFuture([&] {
- buildReadTaskForWriteNodeTable(
- db_context,
- scan_context,
- snapshot_id,
- resp.store_id(),
- req->address(),
- serialized_physical_table,
- output_lock,
- output_seg_tasks);
- });
+ auto f = BuildReadTaskForWNTablePool::get().scheduleWithFuture(
+ [&] {
+ buildReadTaskForWriteNodeTable(
+ db_context,
+ scan_context,
+ snapshot_id,
+ resp.store_id(),
+ req->address(),
+ serialized_physical_table,
+ output_lock,
+ output_seg_tasks);
+ },
+ getBuildTaskIOThreadPoolTimeout());
futures.add(std::move(f));
}
futures.getAllResults();
@@ -395,7 +398,6 @@ void StorageDisaggregated::buildReadTaskForWriteNodeTable(
RUNTIME_CHECK_MSG(parse_ok, "Failed to deserialize RemotePhysicalTable from response");
auto table_tracing_logger = log->getChild(
fmt::format("store_id={} keyspace={} table_id={}", store_id, table.keyspace_id(), table.table_id()));
- auto disagg_build_task_timeout_us = db_context.getSettingsRef().disagg_build_task_timeout * 1000000;

IOPoolHelper::FutureContainer futures(log, table.segments().size());
for (const auto & remote_seg : table.segments())
@@ -415,7 +417,7 @@ void StorageDisaggregated::buildReadTaskForWriteNodeTable(
std::lock_guard lock(output_lock);
output_seg_tasks.push_back(seg_read_task);
},
- disagg_build_task_timeout_us);
+ getBuildTaskIOThreadPoolTimeout());
futures.add(std::move(f));
}
futures.getAllResults();
@@ -690,4 +692,14 @@ void StorageDisaggregated::buildRemoteSegmentSourceOps(
group_builder.getCurProfileInfos());
}

size_t StorageDisaggregated::getBuildTaskRPCTimeout() const
{
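// The raw setting value, passed unchanged to rpc.setClientContext() for EstablishDisaggTask.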
return context.getSettingsRef().disagg_build_task_timeout;
}

size_t StorageDisaggregated::getBuildTaskIOThreadPoolTimeout() const
{
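// Converted to microseconds to match the wait_timeout_us parameter of scheduleWithFuture.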
return context.getSettingsRef().disagg_build_task_timeout * 1000000;
}

} // namespace DB
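Keeping both conversions behind named getters is the point of this refactor: the RPC layer and the thread pools expect different units, and the inline * 1000000 that buildReadTaskForWriteNodeTable previously carried now lives in one place.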
