Skip to content

Commit

Permalink
Fix notifications (#716)
Browse files Browse the repository at this point in the history
## What do these changes do?

This PR fixes hang in `AsyncQueue.join()` by replacing `asyncio.Event`
with `asyncio.Condition`. It also:

1. Makes the names of primitives the same style.
2. Reduces notifications in #704 style.
3. Ensures that counters are changed exclusively.

## Are there changes in behavior for the user?

There are no behavior changes for users.

## Related issue number

Fixes #715
  • Loading branch information
x42005e1f authored Dec 11, 2024
1 parent f90a38a commit a85cc40
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 65 deletions.
4 changes: 3 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Changes

- Optimize internal implementation for a little speedup #699

- Make not-full and not-empty notificatios faster #703
- Make not-full and not-empty notifications faster #703

- Add ``.aclose()`` async method #709

Expand All @@ -16,6 +16,8 @@ Changes

- Remove sync notifiers for a major speedup #714

- Fix hang in ``AsyncQueue.join()`` #716

1.1.0 (2024-10-30)
------------------

Expand Down
140 changes: 76 additions & 64 deletions janus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def __init__(self, maxsize: int = 0) -> None:
self._sync_not_empty_waiting = 0
self._sync_not_full = threading.Condition(self._sync_mutex)
self._sync_not_full_waiting = 0
self._all_tasks_done = threading.Condition(self._sync_mutex)
self._sync_tasks_done = threading.Condition(self._sync_mutex)
self._sync_tasks_done_waiting = 0

self._async_mutex = asyncio.Lock()
if sys.version_info[:3] == (3, 10, 0):
Expand All @@ -104,8 +105,8 @@ def __init__(self, maxsize: int = 0) -> None:
self._async_not_empty_waiting = 0
self._async_not_full = asyncio.Condition(self._async_mutex)
self._async_not_full_waiting = 0
self._finished = asyncio.Event()
self._finished.set()
self._async_tasks_done = asyncio.Condition(self._async_mutex)
self._async_tasks_done_waiting = 0

self._closing = False
self._pending: deque[asyncio.Future[Any]] = deque()
Expand Down Expand Up @@ -142,8 +143,13 @@ def close(self) -> None:
self._closing = True
for fut in self._pending:
fut.cancel()
self._finished.set() # unblocks all async_q.join()
self._all_tasks_done.notify_all() # unblocks all sync_q.join()
if self._async_tasks_done_waiting:
if self._loop is not None:
self._call_soon_threadsafe( # unblocks all async_q.join()
self._make_async_tasks_done_notifier, self._loop
)
if self._sync_tasks_done_waiting:
self._sync_tasks_done.notify_all() # unblocks all sync_q.join()

async def wait_closed(self) -> None:
# should be called from loop after close().
Expand Down Expand Up @@ -201,7 +207,6 @@ def _get(self) -> T:
def _put_internal(self, item: T) -> None:
self._put(item)
self._unfinished_tasks += 1
self._finished.clear()

async def _async_not_empty_notifier(self) -> None:
async with self._async_mutex:
Expand All @@ -221,6 +226,15 @@ def _make_async_not_full_notifier(self, loop: asyncio.AbstractEventLoop) -> None
task.add_done_callback(self._pending.remove)
self._pending.append(task)

async def _async_tasks_done_notifier(self) -> None:
async with self._async_mutex:
self._async_tasks_done.notify_all()

def _make_async_tasks_done_notifier(self, loop: asyncio.AbstractEventLoop) -> None:
task = loop.create_task(self._async_tasks_done_notifier())
task.add_done_callback(self._pending.remove)
self._pending.append(task)

def _check_closing(self) -> None:
if self._closing:
raise RuntimeError("Operation on the closed queue is forbidden")
Expand Down Expand Up @@ -259,13 +273,18 @@ def task_done(self) -> None:
"""
parent = self._parent
parent._check_closing()
with parent._all_tasks_done:
with parent._sync_tasks_done:
unfinished = parent._unfinished_tasks - 1
if unfinished <= 0:
if unfinished < 0:
raise ValueError("task_done() called too many times")
parent._all_tasks_done.notify_all()
parent._call_soon_threadsafe(parent._finished.set)
if parent._sync_tasks_done_waiting:
parent._sync_tasks_done.notify_all()
if parent._async_tasks_done_waiting:
if parent._loop is not None:
parent._call_soon_threadsafe(
parent._make_async_tasks_done_notifier, parent._loop
)
parent._unfinished_tasks = unfinished

def join(self) -> None:
Expand All @@ -279,9 +298,13 @@ def join(self) -> None:
"""
parent = self._parent
parent._check_closing()
with parent._all_tasks_done:
with parent._sync_tasks_done:
while parent._unfinished_tasks:
parent._all_tasks_done.wait()
parent._sync_tasks_done_waiting += 1
try:
parent._sync_tasks_done.wait()
finally:
parent._sync_tasks_done_waiting -= 1
parent._check_closing()

def qsize(self) -> int:
Expand Down Expand Up @@ -486,33 +509,22 @@ async def put(self, item: T) -> None:
parent = self._parent
parent._check_closing()
async with parent._async_not_full:
parent._sync_mutex.acquire()
locked = True
parent._get_loop() # check the event loop
try:
if parent._maxsize > 0:
do_wait = True
while do_wait:
do_wait = parent._qsize() >= parent._maxsize
if do_wait:
locked = False
parent._sync_mutex.release()
parent._async_not_full_waiting += 1
try:
await parent._async_not_full.wait()
finally:
parent._async_not_full_waiting -= 1
parent._sync_mutex.acquire()
locked = True
with parent._sync_mutex:
parent._get_loop() # check the event loop
while 0 < parent._maxsize <= parent._qsize():
parent._async_not_full_waiting += 1
parent._sync_mutex.release()
try:
await parent._async_not_full.wait()
finally:
parent._sync_mutex.acquire()
parent._async_not_full_waiting -= 1

parent._put_internal(item)
if parent._async_not_empty_waiting:
parent._async_not_empty.notify()
if parent._sync_not_empty_waiting:
parent._sync_not_empty.notify()
finally:
if locked:
parent._sync_mutex.release()

def put_nowait(self, item: T) -> None:
"""Put an item into the queue without blocking.
Expand All @@ -523,9 +535,8 @@ def put_nowait(self, item: T) -> None:
parent._check_closing()
with parent._sync_mutex:
loop = parent._get_loop()
if parent._maxsize > 0:
if parent._qsize() >= parent._maxsize:
raise AsyncQueueFull
if 0 < parent._maxsize <= parent._qsize():
raise AsyncQueueFull

parent._put_internal(item)
if parent._async_not_empty_waiting:
Expand All @@ -543,34 +554,23 @@ async def get(self) -> T:
parent = self._parent
parent._check_closing()
async with parent._async_not_empty:
parent._sync_mutex.acquire()
locked = True
parent._get_loop() # check the event loop
try:
do_wait = True
while do_wait:
do_wait = parent._qsize() == 0

if do_wait:
locked = False
parent._sync_mutex.release()
parent._async_not_empty_waiting += 1
try:
await parent._async_not_empty.wait()
finally:
parent._async_not_empty_waiting -= 1
with parent._sync_mutex:
parent._get_loop() # check the event loop
while not parent._qsize():
parent._async_not_empty_waiting += 1
parent._sync_mutex.release()
try:
await parent._async_not_empty.wait()
finally:
parent._sync_mutex.acquire()
locked = True
parent._async_not_empty_waiting -= 1

item = parent._get()
if parent._async_not_full_waiting:
parent._async_not_full.notify()
if parent._sync_not_full_waiting:
parent._sync_not_full.notify()
return item
finally:
if locked:
parent._sync_mutex.release()

def get_nowait(self) -> T:
"""Remove and return an item from the queue.
Expand All @@ -580,7 +580,7 @@ def get_nowait(self) -> T:
parent = self._parent
parent._check_closing()
with parent._sync_mutex:
if parent._qsize() == 0:
if not parent._qsize():
raise AsyncQueueEmpty

loop = parent._get_loop()
Expand Down Expand Up @@ -608,13 +608,18 @@ def task_done(self) -> None:
"""
parent = self._parent
parent._check_closing()
with parent._all_tasks_done:
with parent._sync_tasks_done:
if parent._unfinished_tasks <= 0:
raise ValueError("task_done() called too many times")
parent._unfinished_tasks -= 1
if parent._unfinished_tasks == 0:
parent._finished.set()
parent._all_tasks_done.notify_all()
if parent._async_tasks_done_waiting:
if parent._loop is not None:
parent._call_soon_threadsafe(
parent._make_async_tasks_done_notifier, parent._loop
)
if parent._sync_tasks_done_waiting:
parent._sync_tasks_done.notify_all()

async def join(self) -> None:
"""Block until all items in the queue have been gotten and processed.
Expand All @@ -625,12 +630,19 @@ async def join(self) -> None:
When the count of unfinished tasks drops to zero, join() unblocks.
"""
parent = self._parent
while True:
parent._check_closing()
async with parent._async_tasks_done:
with parent._sync_mutex:
parent._check_closing()
if parent._unfinished_tasks == 0:
break
await parent._finished.wait()
parent._get_loop() # check the event loop
while parent._unfinished_tasks:
parent._async_tasks_done_waiting += 1
parent._sync_mutex.release()
try:
await parent._async_tasks_done.wait()
finally:
parent._sync_mutex.acquire()
parent._async_tasks_done_waiting -= 1
parent._check_closing()


class PriorityQueue(Queue[T]):
Expand Down

0 comments on commit a85cc40

Please sign in to comment.