diff --git a/CHANGES.rst b/CHANGES.rst index 7f700ab..7df434a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,7 @@ Changes - Optimize internal implementation for a little speedup #699 -- Make not-full and not-empty notificatios faster #703 +- Make not-full and not-empty notifications faster #703 - Add ``.aclose()`` async method #709 @@ -16,6 +16,8 @@ Changes - Remove sync notifiers for a major speedup #714 +- Fix hang in ``AsyncQueue.join()`` #716 + 1.1.0 (2024-10-30) ------------------ diff --git a/janus/__init__.py b/janus/__init__.py index 23c4076..1f7db59 100644 --- a/janus/__init__.py +++ b/janus/__init__.py @@ -94,7 +94,8 @@ def __init__(self, maxsize: int = 0) -> None: self._sync_not_empty_waiting = 0 self._sync_not_full = threading.Condition(self._sync_mutex) self._sync_not_full_waiting = 0 - self._all_tasks_done = threading.Condition(self._sync_mutex) + self._sync_tasks_done = threading.Condition(self._sync_mutex) + self._sync_tasks_done_waiting = 0 self._async_mutex = asyncio.Lock() if sys.version_info[:3] == (3, 10, 0): @@ -104,8 +105,8 @@ def __init__(self, maxsize: int = 0) -> None: self._async_not_empty_waiting = 0 self._async_not_full = asyncio.Condition(self._async_mutex) self._async_not_full_waiting = 0 - self._finished = asyncio.Event() - self._finished.set() + self._async_tasks_done = asyncio.Condition(self._async_mutex) + self._async_tasks_done_waiting = 0 self._closing = False self._pending: deque[asyncio.Future[Any]] = deque() @@ -142,8 +143,13 @@ def close(self) -> None: self._closing = True for fut in self._pending: fut.cancel() - self._finished.set() # unblocks all async_q.join() - self._all_tasks_done.notify_all() # unblocks all sync_q.join() + if self._async_tasks_done_waiting: + if self._loop is not None: + self._call_soon_threadsafe( # unblocks all async_q.join() + self._make_async_tasks_done_notifier, self._loop + ) + if self._sync_tasks_done_waiting: + self._sync_tasks_done.notify_all() # unblocks all sync_q.join() async def wait_closed(self) -> None: # should be called from loop after close(). @@ -201,7 +207,6 @@ def _get(self) -> T: def _put_internal(self, item: T) -> None: self._put(item) self._unfinished_tasks += 1 - self._finished.clear() async def _async_not_empty_notifier(self) -> None: async with self._async_mutex: @@ -221,6 +226,15 @@ def _make_async_not_full_notifier(self, loop: asyncio.AbstractEventLoop) -> None task.add_done_callback(self._pending.remove) self._pending.append(task) + async def _async_tasks_done_notifier(self) -> None: + async with self._async_mutex: + self._async_tasks_done.notify_all() + + def _make_async_tasks_done_notifier(self, loop: asyncio.AbstractEventLoop) -> None: + task = loop.create_task(self._async_tasks_done_notifier()) + task.add_done_callback(self._pending.remove) + self._pending.append(task) + def _check_closing(self) -> None: if self._closing: raise RuntimeError("Operation on the closed queue is forbidden") @@ -259,13 +273,18 @@ def task_done(self) -> None: """ parent = self._parent parent._check_closing() - with parent._all_tasks_done: + with parent._sync_tasks_done: unfinished = parent._unfinished_tasks - 1 if unfinished <= 0: if unfinished < 0: raise ValueError("task_done() called too many times") - parent._all_tasks_done.notify_all() - parent._call_soon_threadsafe(parent._finished.set) + if parent._sync_tasks_done_waiting: + parent._sync_tasks_done.notify_all() + if parent._async_tasks_done_waiting: + if parent._loop is not None: + parent._call_soon_threadsafe( + parent._make_async_tasks_done_notifier, parent._loop + ) parent._unfinished_tasks = unfinished def join(self) -> None: @@ -279,9 +298,13 @@ def join(self) -> None: """ parent = self._parent parent._check_closing() - with parent._all_tasks_done: + with parent._sync_tasks_done: while parent._unfinished_tasks: - parent._all_tasks_done.wait() + parent._sync_tasks_done_waiting += 1 + try: + parent._sync_tasks_done.wait() + finally: + parent._sync_tasks_done_waiting -= 1 parent._check_closing() def qsize(self) -> int: @@ -486,33 +509,22 @@ async def put(self, item: T) -> None: parent = self._parent parent._check_closing() async with parent._async_not_full: - parent._sync_mutex.acquire() - locked = True - parent._get_loop() # check the event loop - try: - if parent._maxsize > 0: - do_wait = True - while do_wait: - do_wait = parent._qsize() >= parent._maxsize - if do_wait: - locked = False - parent._sync_mutex.release() - parent._async_not_full_waiting += 1 - try: - await parent._async_not_full.wait() - finally: - parent._async_not_full_waiting -= 1 - parent._sync_mutex.acquire() - locked = True + with parent._sync_mutex: + parent._get_loop() # check the event loop + while 0 < parent._maxsize <= parent._qsize(): + parent._async_not_full_waiting += 1 + parent._sync_mutex.release() + try: + await parent._async_not_full.wait() + finally: + parent._sync_mutex.acquire() + parent._async_not_full_waiting -= 1 parent._put_internal(item) if parent._async_not_empty_waiting: parent._async_not_empty.notify() if parent._sync_not_empty_waiting: parent._sync_not_empty.notify() - finally: - if locked: - parent._sync_mutex.release() def put_nowait(self, item: T) -> None: """Put an item into the queue without blocking. @@ -523,9 +535,8 @@ def put_nowait(self, item: T) -> None: parent._check_closing() with parent._sync_mutex: loop = parent._get_loop() - if parent._maxsize > 0: - if parent._qsize() >= parent._maxsize: - raise AsyncQueueFull + if 0 < parent._maxsize <= parent._qsize(): + raise AsyncQueueFull parent._put_internal(item) if parent._async_not_empty_waiting: @@ -543,24 +554,16 @@ async def get(self) -> T: parent = self._parent parent._check_closing() async with parent._async_not_empty: - parent._sync_mutex.acquire() - locked = True - parent._get_loop() # check the event loop - try: - do_wait = True - while do_wait: - do_wait = parent._qsize() == 0 - - if do_wait: - locked = False - parent._sync_mutex.release() - parent._async_not_empty_waiting += 1 - try: - await parent._async_not_empty.wait() - finally: - parent._async_not_empty_waiting -= 1 + with parent._sync_mutex: + parent._get_loop() # check the event loop + while not parent._qsize(): + parent._async_not_empty_waiting += 1 + parent._sync_mutex.release() + try: + await parent._async_not_empty.wait() + finally: parent._sync_mutex.acquire() - locked = True + parent._async_not_empty_waiting -= 1 item = parent._get() if parent._async_not_full_waiting: @@ -568,9 +571,6 @@ async def get(self) -> T: if parent._sync_not_full_waiting: parent._sync_not_full.notify() return item - finally: - if locked: - parent._sync_mutex.release() def get_nowait(self) -> T: """Remove and return an item from the queue. @@ -580,7 +580,7 @@ def get_nowait(self) -> T: parent = self._parent parent._check_closing() with parent._sync_mutex: - if parent._qsize() == 0: + if not parent._qsize(): raise AsyncQueueEmpty loop = parent._get_loop() @@ -608,13 +608,18 @@ def task_done(self) -> None: """ parent = self._parent parent._check_closing() - with parent._all_tasks_done: + with parent._sync_tasks_done: if parent._unfinished_tasks <= 0: raise ValueError("task_done() called too many times") parent._unfinished_tasks -= 1 if parent._unfinished_tasks == 0: - parent._finished.set() - parent._all_tasks_done.notify_all() + if parent._async_tasks_done_waiting: + if parent._loop is not None: + parent._call_soon_threadsafe( + parent._make_async_tasks_done_notifier, parent._loop + ) + if parent._sync_tasks_done_waiting: + parent._sync_tasks_done.notify_all() async def join(self) -> None: """Block until all items in the queue have been gotten and processed. @@ -625,12 +630,19 @@ async def join(self) -> None: When the count of unfinished tasks drops to zero, join() unblocks. """ parent = self._parent - while True: + parent._check_closing() + async with parent._async_tasks_done: with parent._sync_mutex: - parent._check_closing() - if parent._unfinished_tasks == 0: - break - await parent._finished.wait() + parent._get_loop() # check the event loop + while parent._unfinished_tasks: + parent._async_tasks_done_waiting += 1 + parent._sync_mutex.release() + try: + await parent._async_tasks_done.wait() + finally: + parent._sync_mutex.acquire() + parent._async_tasks_done_waiting -= 1 + parent._check_closing() class PriorityQueue(Queue[T]):