diff --git a/CHANGES.rst b/CHANGES.rst index c8d60c2..648ef02 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -12,6 +12,8 @@ Changes - Reduce notifications for a minor speedup #704 +- Allow ``janus.Queue()`` instantiation without running asyncio event loop #710 + 1.1.0 (2024-10-30) ------------------ diff --git a/janus/__init__.py b/janus/__init__.py index cbfa46f..925dc2f 100644 --- a/janus/__init__.py +++ b/janus/__init__.py @@ -7,6 +7,7 @@ from heapq import heappop, heappush from queue import Empty as SyncQueueEmpty from queue import Full as SyncQueueFull +from time import monotonic from typing import Any, Callable, Generic, Optional, Protocol, TypeVar __version__ = "1.1.0" @@ -24,7 +25,14 @@ ) +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + + T = TypeVar("T") +P = ParamSpec("P") OptFloat = Optional[float] @@ -69,8 +77,12 @@ async def join(self) -> None: ... class Queue(Generic[T]): + _loop: Optional[asyncio.AbstractEventLoop] = None + def __init__(self, maxsize: int = 0) -> None: - self._loop = asyncio.get_running_loop() + if sys.version_info < (3, 10): + self._loop = asyncio.get_running_loop() + self._maxsize = maxsize self._init(maxsize) @@ -98,26 +110,33 @@ def __init__(self, maxsize: int = 0) -> None: self._closing = False self._pending: deque[asyncio.Future[Any]] = deque() - def checked_call_soon_threadsafe( - callback: Callable[..., None], *args: Any - ) -> None: - try: - self._loop.call_soon_threadsafe(callback, *args) - except RuntimeError: - # swallowing agreed in #2 - pass - - self._call_soon_threadsafe = checked_call_soon_threadsafe - - def checked_call_soon(callback: Callable[..., None], *args: Any) -> None: - if not self._loop.is_closed(): - self._loop.call_soon(callback, *args) - - self._call_soon = checked_call_soon - self._sync_queue = _SyncQueueProxy(self) self._async_queue = _AsyncQueueProxy(self) + def _call_soon_threadsafe( + self, callback: Callable[P, None], *args: P.args, **kwargs: P.kwargs + ) -> None: + if self._loop is None: + # async API didn't accessed yet, nothing to notify + return + try: + self._loop.call_soon_threadsafe(callback, *args) + except RuntimeError: + # swallowing agreed in #2 + pass + + def _get_loop(self) -> asyncio.AbstractEventLoop: + # Warning! + # The function should be called when self._sync_mutex is locked, + # otherwise the code is not thread-safe + loop = asyncio.get_running_loop() + + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + def close(self) -> None: with self._sync_mutex: self._closing = True @@ -134,7 +153,7 @@ async def wait_closed(self) -> None: raise RuntimeError("Waiting for non-closed queue") # give execution chances for the task-done callbacks # of async tasks created inside - # _notify_async_not_empty, _notify_async_not_full + # _make_async_not_empty_notifier, _make_async_not_full_notifier # methods. await asyncio.sleep(0) if not self._pending: @@ -188,8 +207,8 @@ def _sync_not_empty_notifier(self) -> None: with self._sync_mutex: self._sync_not_empty.notify() - def _notify_sync_not_empty(self) -> None: - fut = self._loop.run_in_executor(None, self._sync_not_empty_notifier) + def _notify_sync_not_empty(self, loop: asyncio.AbstractEventLoop) -> None: + fut = loop.run_in_executor(None, self._sync_not_empty_notifier) fut.add_done_callback(self._pending.remove) self._pending.append(fut) @@ -197,8 +216,8 @@ def _sync_not_full_notifier(self) -> None: with self._sync_mutex: self._sync_not_full.notify() - def _notify_sync_not_full(self) -> None: - fut = self._loop.run_in_executor(None, self._sync_not_full_notifier) + def _notify_sync_not_full(self, loop: asyncio.AbstractEventLoop) -> None: + fut = loop.run_in_executor(None, self._sync_not_full_notifier) fut.add_done_callback(self._pending.remove) self._pending.append(fut) @@ -206,32 +225,20 @@ async def _async_not_empty_notifier(self) -> None: async with self._async_mutex: self._async_not_empty.notify() - def _make_async_not_empty_notifier(self) -> None: - task = self._loop.create_task(self._async_not_empty_notifier()) + def _make_async_not_empty_notifier(self, loop: asyncio.AbstractEventLoop) -> None: + task = loop.create_task(self._async_not_empty_notifier()) task.add_done_callback(self._pending.remove) self._pending.append(task) - def _notify_async_not_empty(self, *, threadsafe: bool) -> None: - if threadsafe: - self._call_soon_threadsafe(self._make_async_not_empty_notifier) - else: - self._make_async_not_empty_notifier() - async def _async_not_full_notifier(self) -> None: async with self._async_mutex: self._async_not_full.notify() - def _make_async_not_full_notifier(self) -> None: - task = self._loop.create_task(self._async_not_full_notifier()) + def _make_async_not_full_notifier(self, loop: asyncio.AbstractEventLoop) -> None: + task = loop.create_task(self._async_not_full_notifier()) task.add_done_callback(self._pending.remove) self._pending.append(task) - def _notify_async_not_full(self, *, threadsafe: bool) -> None: - if threadsafe: - self._call_soon_threadsafe(self._make_async_not_full_notifier) - else: - self._make_async_not_full_notifier() - def _check_closing(self) -> None: if self._closing: raise RuntimeError("Operation on the closed queue is forbidden") @@ -276,7 +283,7 @@ def task_done(self) -> None: if unfinished < 0: raise ValueError("task_done() called too many times") parent._all_tasks_done.notify_all() - parent._loop.call_soon_threadsafe(parent._finished.set) + parent._call_soon_threadsafe(parent._finished.set) parent._unfinished_tasks = unfinished def join(self) -> None: @@ -356,10 +363,9 @@ def put(self, item: T, block: bool = True, timeout: OptFloat = None) -> None: elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: - time = parent._loop.time - endtime = time() + timeout + endtime = monotonic() + timeout while parent._qsize() >= parent._maxsize: - remaining = endtime - time() + remaining = endtime - monotonic() if remaining <= 0.0: raise SyncQueueFull parent._sync_not_full_waiting += 1 @@ -371,7 +377,10 @@ def put(self, item: T, block: bool = True, timeout: OptFloat = None) -> None: if parent._sync_not_empty_waiting: parent._sync_not_empty.notify() if parent._async_not_empty_waiting: - parent._notify_async_not_empty(threadsafe=True) + if parent._loop is not None: + parent._call_soon_threadsafe( + parent._make_async_not_empty_notifier, parent._loop + ) def get(self, block: bool = True, timeout: OptFloat = None) -> T: """Remove and return an item from the queue. @@ -400,10 +409,9 @@ def get(self, block: bool = True, timeout: OptFloat = None) -> T: elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: - time = parent._loop.time - endtime = time() + timeout + endtime = monotonic() + timeout while not parent._qsize(): - remaining = endtime - time() + remaining = endtime - monotonic() if remaining <= 0.0: raise SyncQueueEmpty parent._sync_not_empty_waiting += 1 @@ -415,7 +423,10 @@ def get(self, block: bool = True, timeout: OptFloat = None) -> T: if parent._sync_not_full_waiting: parent._sync_not_full.notify() if parent._async_not_full_waiting: - parent._notify_async_not_full(threadsafe=True) + if parent._loop is not None: + parent._call_soon_threadsafe( + parent._make_async_not_full_notifier, parent._loop + ) return item def put_nowait(self, item: T) -> None: @@ -446,21 +457,25 @@ def __init__(self, parent: Queue[T]): @property def closed(self) -> bool: - return self._parent.closed + parent = self._parent + return parent.closed def qsize(self) -> int: """Number of items in the queue.""" - return self._parent._qsize() + parent = self._parent + return parent._qsize() @property def unfinished_tasks(self) -> int: """Return the number of unfinished tasks.""" - return self._parent._unfinished_tasks + parent = self._parent + return parent._unfinished_tasks @property def maxsize(self) -> int: """Number of items allowed in the queue.""" - return self._parent._maxsize + parent = self._parent + return parent._maxsize def empty(self) -> bool: """Return True if the queue is empty, False otherwise.""" @@ -491,6 +506,7 @@ async def put(self, item: T) -> None: async with parent._async_not_full: parent._sync_mutex.acquire() locked = True + loop = parent._get_loop() try: if parent._maxsize > 0: do_wait = True @@ -511,7 +527,7 @@ async def put(self, item: T) -> None: if parent._async_not_empty_waiting: parent._async_not_empty.notify() if parent._sync_not_empty_waiting: - parent._notify_sync_not_empty() + parent._notify_sync_not_empty(loop) finally: if locked: parent._sync_mutex.release() @@ -524,15 +540,16 @@ def put_nowait(self, item: T) -> None: parent = self._parent parent._check_closing() with parent._sync_mutex: + loop = parent._get_loop() if parent._maxsize > 0: if parent._qsize() >= parent._maxsize: raise AsyncQueueFull parent._put_internal(item) if parent._async_not_empty_waiting: - parent._notify_async_not_empty(threadsafe=False) + parent._make_async_not_empty_notifier(loop) if parent._sync_not_empty_waiting: - parent._notify_sync_not_empty() + parent._notify_sync_not_empty(loop) async def get(self) -> T: """Remove and return an item from the queue. @@ -546,6 +563,7 @@ async def get(self) -> T: async with parent._async_not_empty: parent._sync_mutex.acquire() locked = True + loop = parent._get_loop() try: do_wait = True while do_wait: @@ -566,7 +584,7 @@ async def get(self) -> T: if parent._async_not_full_waiting: parent._async_not_full.notify() if parent._sync_not_full_waiting: - parent._notify_sync_not_full() + parent._notify_sync_not_full(loop) return item finally: if locked: @@ -583,11 +601,13 @@ def get_nowait(self) -> T: if parent._qsize() == 0: raise AsyncQueueEmpty + loop = parent._get_loop() + item = parent._get() if parent._async_not_full_waiting: - parent._notify_async_not_full(threadsafe=False) + parent._make_async_not_full_notifier(loop) if parent._sync_not_full_waiting: - parent._notify_sync_not_full() + parent._notify_sync_not_full(loop) return item def task_done(self) -> None: diff --git a/setup.cfg b/setup.cfg index ceb42b6..514a241 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,10 @@ packages = find: zip_safe = True include_package_data = True +install_requires = + typing-extensions >= 4.0.0; python_version < "3.10" + + [flake8] exclude = .git,.env,__pycache__,.eggs max-line-length = 88 diff --git a/tests/test_mixed.py b/tests/test_mixed.py index 108ef50..cf4613c 100644 --- a/tests/test_mixed.py +++ b/tests/test_mixed.py @@ -1,4 +1,5 @@ import asyncio +import sys import threading from concurrent.futures import ThreadPoolExecutor @@ -9,10 +10,30 @@ class TestMixedMode: + @pytest.mark.skipif( + sys.version_info >= (3, 10), + reason="Python 3.10+ supports delayed initialization", + ) def test_ctor_noloop(self): with pytest.raises(RuntimeError): janus.Queue() + @pytest.mark.asyncio + async def test_get_loop_ok(self): + q = janus.Queue() + loop = asyncio.get_running_loop() + assert q._get_loop() is loop + assert q._loop is loop + + @pytest.mark.asyncio + async def test_get_loop_different_loop(self): + q = janus.Queue() + # emulate binding another loop + loop = q._loop = asyncio.new_event_loop() + with pytest.raises(RuntimeError, match="is bound to a different event loop"): + q._get_loop() + loop.close() + @pytest.mark.asyncio async def test_maxsize(self): q = janus.Queue(5) @@ -349,10 +370,7 @@ async def test_put_notifies_async_not_empty(self): loop = asyncio.get_running_loop() q = janus.Queue() - tasks = [ - loop.create_task(q.async_q.get()) - for _ in range(4) - ] + tasks = [loop.create_task(q.async_q.get()) for _ in range(4)] while q._async_not_empty_waiting != 4: await asyncio.sleep(0) @@ -395,10 +413,7 @@ async def test_get_notifies_async_not_full(self): q.sync_q.put_nowait(1) q.sync_q.put_nowait(2) - tasks = [ - loop.create_task(q.async_q.put(object())) - for _ in range(4) - ] + tasks = [loop.create_task(q.async_q.put(object())) for _ in range(4)] while q._async_not_full_waiting != 4: await asyncio.sleep(0) diff --git a/tests/test_sync.py b/tests/test_sync.py index 9ce53e2..f427c91 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -3,6 +3,7 @@ import asyncio import queue import re +import sys import threading import time from unittest.mock import patch @@ -423,3 +424,13 @@ async def test_closed_loop_non_failing(self): assert func.call_count == 1 _q.close() await _q.wait_closed() + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="Python 3.10+ is required", +) +def test_sync_only_api(): + q = janus.Queue() + q.sync_q.put(1) + assert q.sync_q.get() == 1