Skip to content

Commit

Permalink
Common: Add TaskQueue class
Browse files Browse the repository at this point in the history
  • Loading branch information
stenzek committed Jan 3, 2025
1 parent 52e6e8f commit 5476015
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ add_library(common
thirdparty/SmallVector.h
thirdparty/aes.cpp
thirdparty/aes.h
task_queue.cpp
task_queue.h
threading.cpp
threading.h
timer.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/common/common.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<ClInclude Include="thirdparty\SmallVector.h" />
<ClInclude Include="thirdparty\StackWalker.h" />
<ClInclude Include="threading.h" />
<ClInclude Include="task_queue.h" />
<ClInclude Include="timer.h" />
<ClInclude Include="types.h" />
<ClInclude Include="minizip_helpers.h" />
Expand Down Expand Up @@ -74,6 +75,7 @@
<ClCompile Include="thirdparty\SmallVector.cpp" />
<ClCompile Include="thirdparty\StackWalker.cpp" />
<ClCompile Include="threading.cpp" />
<ClCompile Include="task_queue.cpp" />
<ClCompile Include="timer.cpp" />
</ItemGroup>
<ItemGroup>
Expand Down
2 changes: 2 additions & 0 deletions src/common/common.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
<ClInclude Include="log_channels.h" />
<ClInclude Include="sha256_digest.h" />
<ClInclude Include="thirdparty\aes.h" />
<ClInclude Include="task_queue.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="small_string.cpp" />
Expand Down Expand Up @@ -82,6 +83,7 @@
<ClCompile Include="gsvector.cpp" />
<ClCompile Include="sha256_digest.cpp" />
<ClCompile Include="thirdparty\aes.cpp" />
<ClCompile Include="task_queue.cpp" />
</ItemGroup>
<ItemGroup>
<Natvis Include="bitfield.natvis" />
Expand Down
98 changes: 98 additions & 0 deletions src/common/task_queue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// SPDX-FileCopyrightText: 2019-2025 Connor McLaughlin <[email protected]>
// SPDX-License-Identifier: CC-BY-NC-ND-4.0

#include "task_queue.h"
#include "assert.h"

TaskQueue::TaskQueue() = default;

TaskQueue::~TaskQueue()
{
SetWorkerCount(0);
Assert(m_tasks.empty());
}

void TaskQueue::SetWorkerCount(u32 count)
{
std::unique_lock lock(m_mutex);

WaitForAll(lock);

if (!m_threads.empty())
{
m_threads_done = true;
m_task_wait_cv.notify_all();

auto threads = std::move(m_threads);
m_threads = decltype(threads)();

lock.unlock();
for (std::thread& t : threads)
t.join();
lock.lock();
}

if (count > 0)
{
m_threads_done = false;
for (u32 i = 0; i < count; i++)
m_threads.emplace_back(&TaskQueue::WorkerThreadEntryPoint, this);
}
}

void TaskQueue::SubmitTask(TaskFunctionType func)
{
std::unique_lock lock(m_mutex);
m_tasks.push_back(std::move(func));
m_tasks_outstanding++;
m_task_wait_cv.notify_one();
}

void TaskQueue::WaitForAll()
{
std::unique_lock lock(m_mutex);
WaitForAll(lock);
}

void TaskQueue::WaitForAll(std::unique_lock<std::mutex>& lock)
{
// while we're waiting, execute work on the calling thread
m_tasks_done_cv.wait(lock, [this, &lock]() {
if (m_tasks_outstanding == 0)
return true;

while (!m_tasks.empty())
ExecuteOneTask(lock);

return (m_tasks_outstanding == 0);
});
}

void TaskQueue::ExecuteOneTask(std::unique_lock<std::mutex>& lock)
{
TaskFunctionType func = std::move(m_tasks.front());
m_tasks.pop_front();
lock.unlock();
func();
lock.lock();
m_tasks_outstanding--;
if (m_tasks_outstanding == 0)
m_tasks_done_cv.notify_all();
}

void TaskQueue::WorkerThreadEntryPoint()
{
Threading::SetNameOfCurrentThread("TaskQueue Worker");

std::unique_lock lock(m_mutex);
while (!m_threads_done)
{
if (m_tasks.empty())
{
m_task_wait_cv.wait(lock);
continue;
}

ExecuteOneTask(lock);
}
}
59 changes: 59 additions & 0 deletions src/common/task_queue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: 2019-2025 Connor McLaughlin <[email protected]>
// SPDX-License-Identifier: CC-BY-NC-ND-4.0

#pragma once

#include "threading.h"
#include "types.h"

#include <atomic>
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

/// Implements a simple task queue with multiple worker threads.
class TaskQueue
{
public:
using TaskFunctionType = std::function<void()>;

TaskQueue();
~TaskQueue();

/// Sets the number of worker threads to be used by the task queue.
/// Setting this to zero threads completes tasks on the calling thread.
/// @param count The desired number of worker threads.
void SetWorkerCount(u32 count);

/// Submits a task to the queue for execution.
/// @param func The task function to execute.
void SubmitTask(TaskFunctionType func);

/// Waits for all submitted tasks to complete execution.
void WaitForAll();

private:
/// Waits for all submitted tasks to complete execution.
/// This is a helper function that assumes a lock is already held.
/// @param lock A unique_lock object holding the mutex.
void WaitForAll(std::unique_lock<std::mutex>& lock);

/// Executes one task from the queue.
/// This is a helper function that assumes a lock is already held.
/// @param lock A unique_lock object holding the mutex.
void ExecuteOneTask(std::unique_lock<std::mutex>& lock);

/// Entry point for worker threads. Executes tasks from the queue until termination is signaled.
void WorkerThreadEntryPoint();

std::mutex m_mutex;
std::deque<TaskFunctionType> m_tasks;
size_t m_tasks_outstanding = 0;
std::condition_variable m_task_wait_cv;
std::condition_variable m_tasks_done_cv;
std::vector<std::thread> m_threads;
bool m_threads_done = false;
};

0 comments on commit 5476015

Please sign in to comment.