Skip to content

Commit

Permalink
Use a traditional semaphore in AsyncioRunnable (#412)
Browse files Browse the repository at this point in the history
Closes nv-morpheus/Morpheus#1339

Replaces the ClosableRingBuffer usage in AsyncioRunnable to instead use a traditional semaphore which seems to be more reliable for this use case.

Authors:
  - Christopher Harris (https://github.com/cwharris)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #412
  • Loading branch information
cwharris authored Nov 3, 2023
1 parent 62e1834 commit 8aa9216
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ class AsyncioRunnable : public AsyncSink<InputT>,
using task_buffer_t = mrc::coroutines::ClosableRingBuffer<size_t>;

public:
AsyncioRunnable(size_t concurrency = 8) : m_concurrency(concurrency){};
~AsyncioRunnable() override = default;

private:
Expand All @@ -199,7 +198,6 @@ class AsyncioRunnable : public AsyncSink<InputT>,
* @brief The per-value coroutine run asynchronously alongside other calls.
*/
coroutines::Task<> process_one(InputT value,
task_buffer_t& task_buffer,
std::shared_ptr<mrc::coroutines::Scheduler> on,
ExceptionCatcher& catcher);

Expand All @@ -211,7 +209,11 @@ class AsyncioRunnable : public AsyncSink<InputT>,

std::stop_source m_stop_source;

size_t m_concurrency{8};
/**
* @brief A semaphore used to control the number of outstanding operations. Acquire one before
* beginning a task, and release it when finished.
*/
std::counting_semaphore<8> m_task_tickets{8};
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -279,15 +281,14 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx)
template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<mrc::coroutines::Scheduler> scheduler)
{
// Create the task buffer to limit the number of running tasks
task_buffer_t task_buffer{{.capacity = m_concurrency}};

coroutines::TaskContainer outstanding_tasks(scheduler);

ExceptionCatcher catcher{};

while (not m_stop_source.stop_requested() and not catcher.has_exception())
{
m_task_tickets.acquire();

InputT data;

auto read_status = co_await this->read_async(data);
Expand All @@ -297,26 +298,16 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
break;
}

// Wait for an available slot in the task buffer
co_await task_buffer.write(0);

outstanding_tasks.start(this->process_one(std::move(data), task_buffer, scheduler, catcher));
outstanding_tasks.start(this->process_one(std::move(data), scheduler, catcher));
}

// Close the buffer
task_buffer.close();

// Now block until all tasks are complete
co_await task_buffer.completed();

co_await outstanding_tasks.garbage_collect_and_yield_until_empty();

catcher.rethrow_next_exception();
}

template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
task_buffer_t& task_buffer,
std::shared_ptr<mrc::coroutines::Scheduler> on,
ExceptionCatcher& catcher)
{
Expand Down Expand Up @@ -344,8 +335,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
catcher.push_exception(std::current_exception());
}

// Return the slot to the task buffer
co_await task_buffer.read();
m_task_tickets.release();
}

template <typename InputT, typename OutputT>
Expand Down

0 comments on commit 8aa9216

Please sign in to comment.