diff --git a/janus/__init__.py b/janus/__init__.py index 45f60da..42f48ef 100644 --- a/janus/__init__.py +++ b/janus/__init__.py @@ -1,13 +1,16 @@ import asyncio import sys import threading +from concurrent.futures import ThreadPoolExecutor, Future, CancelledError +from time import time as time_time +from asyncio import AbstractEventLoop from asyncio import QueueEmpty as AsyncQueueEmpty from asyncio import QueueFull as AsyncQueueFull from collections import deque from heapq import heappop, heappush from queue import Empty as SyncQueueEmpty from queue import Full as SyncQueueFull -from typing import Any, Callable, Deque, Generic, List, Optional, Set, TypeVar +from typing import Any, Callable, Deque, Generic, List, Optional, Set, TypeVar, Union, Tuple from typing_extensions import Protocol @@ -24,6 +27,137 @@ T = TypeVar("T") OptFloat = Optional[float] +PostAsyncInit = Optional[T] + + +class InitAsyncPartsMixin: + @property + def already_initialized(self) -> bool: + """Indicate that instance already initialized""" + raise NotImplementedError() + + @property + def _also_initialize_when_triggered(self) -> List["InitAsyncPartsMixin"]: + """Returns a list of objects whose async parts must also be initialized.""" + return [] + + @property + def _list_of_methods_to_patch(self) -> List[Tuple[str, str]]: + """Return list of ('cur_methods', 'new_method') for monkey-patching + + List of methods whose behavior has been changed to be use without initializing the async parts + """ + return [] + + def _async_post_init_patch_methods(self): + """Monkey patching""" + for method_name, new_method in ((cm, getattr(self, nm)) for cm, nm in self._list_of_methods_to_patch): + setattr(self, method_name, new_method) + + def _async_post_init_handler(self, loop: Optional[AbstractEventLoop] = None, **params) -> Optional[AbstractEventLoop]: + """Handle initializing of asynchronous parts of object""" + if self.already_initialized: + return loop + + if loop is None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + raise RuntimeError("Async parts of 'janus' must be initialized only from running loop. " + "Do not call async from sync code") + + # 'already_initialized' must be True after call + self._async_init(loop, **params) + self._async_post_init_patch_methods() + + for need_init in self._also_initialize_when_triggered: + if not isinstance(need_init, InitAsyncPartsMixin): + raise ValueError("'_also_initialize_when_triggered' must contain only instances of " + "class that inherited from 'InitAsyncPartsMixin'") + need_init._async_post_init_handler(loop) + + return loop + + def _async_init(self, loop: AbstractEventLoop, **params): + """Override to change behavior + + The actions of this function should affect the "value" of 'already_initialized' (set it to True) + """ + ... + + def trigger_async_initialization(self, **params): + """Trigger initialization of async parts + + Public alias for '_async_post_init_handler' + """ + return self._async_post_init_handler(**params) + + +class PreInitDummyLoop: + """Replacement for a 'Queue._loop', until the async part is fully initialized""" + + def __init__(self): + self.executor = ThreadPoolExecutor(thread_name_prefix="PreInitDummyLoop-") + self.pending = set() # type: Set[Future[Any]] + + @staticmethod + def time(): + """Replacement of '_loop.time' in '_SyncQueueProxy.get' and '_SyncQueueProxy.put'""" + return time_time() + + def call_soon_threadsafe(self, callback: Callable[..., None]): + future = self.executor.submit(callback) + self.pending.add(future) + future.add_done_callback(self.pending.discard) + + def run_in_executor(self, callback: Callable[..., None]): + future = self.executor.submit(callback) + self.pending.add(future) + future.add_done_callback(self.pending.discard) + + def wait(self): + for task in self.pending: + try: + task.result() + except CancelledError: + ... + + def cleanup(self): + for task in self.pending: + task.cancel() + + +class PreInitDummyAsyncQueue: + """Replacement of 'Queue.async_q' + + Will trigger initialization of async part, on every access to attrs. + If, after full initialization, someone has a link to it, it starts working as a proxy, + redirecting everything to the actual 'async_q' + """ + + def __init__(self, trigger_obj: "Queue[T]"): + self.trigger_obj = trigger_obj + self.already_triggered = threading.Event() + + def __getattribute__(self, item): + already_triggered = super().__getattribute__("already_triggered") # type: threading.Event + trigger_obj = super().__getattribute__("trigger_obj") # type: Queue[T] + + if already_triggered.is_set(): + async_q = getattr(trigger_obj, "async_q") # type: Union[_AsyncQueueProxy[T], PreInitDummyAsyncQueue] + + if not isinstance(async_q, _AsyncQueueProxy): + raise RuntimeError("Async parts multi-initialization detected. You cannot access 'async_q' attrs " + "until full initialization is complete") + return getattr(async_q, item) + + already_triggered.set() + trigger_obj.trigger_async_initialization() + async_q = getattr(trigger_obj, "async_q") # type: _AsyncQueueProxy[T] + + if isinstance(async_q, PreInitDummyAsyncQueue): + raise RuntimeError("Error during post initializing. 'async_q' must be replaced with actual 'AsyncQueue'") + return getattr(async_q, item) class BaseQueue(Protocol[T]): @@ -81,11 +215,13 @@ async def join(self) -> None: ... -class Queue(Generic[T]): - def __init__(self, maxsize: int = 0) -> None: - self._loop = asyncio.get_running_loop() +class Queue(Generic[T], InitAsyncPartsMixin): + def __init__(self, maxsize: int = 0, init_async_part: bool = True) -> None: self._maxsize = maxsize + # will be set after the async part is initialized + self.full_init = threading.Event() + self._init(maxsize) self._unfinished_tasks = 0 @@ -95,20 +231,60 @@ def __init__(self, maxsize: int = 0) -> None: self._sync_not_full = threading.Condition(self._sync_mutex) self._all_tasks_done = threading.Condition(self._sync_mutex) + self._closing = False + self._pending = set() # type: Set[asyncio.Future[Any]] + + self._loop = PreInitDummyLoop() # type: Union[PreInitDummyLoop, AbstractEventLoop] + + self._async_mutex = asyncio.Lock() # type: PostAsyncInit[asyncio.Lock] + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug, see #358: + getattr(self._async_mutex, "_get_loop", lambda: None)() + self._async_not_empty = None # type: PostAsyncInit[asyncio.Condition] + self._async_not_full = None # type: PostAsyncInit[asyncio.Condition] + # set 'threading.Event' to not change behavior + self._finished = threading.Event() # type: Union[asyncio.Event, threading.Event] + + def before_init_async_parts_dummy_handler( + callback: Callable[..., None], *args: Any + ) -> None: + callback(*args) + + self._call_soon_threadsafe = before_init_async_parts_dummy_handler + + self._call_soon = before_init_async_parts_dummy_handler + + self._sync_queue = _SyncQueueProxy(self) + self._async_queue = PreInitDummyAsyncQueue(self) # type: Union[PreInitDummyAsyncQueue, "_AsyncQueueProxy[T]"] + + if init_async_part: + self.trigger_async_initialization() + + @property + def already_initialized(self) -> bool: + """Return True if all parts of 'Queue'(sync/async) are initialized""" + return self.full_init.is_set() + + def _async_init(self, loop: AbstractEventLoop, **params): + self._loop = loop + + self._async_queue = _AsyncQueueProxy(self) self._async_mutex = asyncio.Lock() if sys.version_info[:3] == (3, 10, 0): # Workaround for Python 3.10 bug, see #358: getattr(self._async_mutex, "_get_loop", lambda: None)() self._async_not_empty = asyncio.Condition(self._async_mutex) self._async_not_full = asyncio.Condition(self._async_mutex) + + _finished = self._finished self._finished = asyncio.Event() self._finished.set() - self._closing = False - self._pending = set() # type: Set[asyncio.Future[Any]] + if not _finished.is_set(): + _finished.set() def checked_call_soon_threadsafe( - callback: Callable[..., None], *args: Any + callback: Callable[..., None], *args: Any ) -> None: try: self._loop.call_soon_threadsafe(callback, *args) @@ -124,14 +300,23 @@ def checked_call_soon(callback: Callable[..., None], *args: Any) -> None: self._call_soon = checked_call_soon - self._sync_queue = _SyncQueueProxy(self) - self._async_queue = _AsyncQueueProxy(self) + self.full_init.set() + + @property + def _list_of_methods_to_patch(self) -> List[Tuple[str, str]]: + return [ + ("_notify_sync_condition", "_post_async_init_notify_sync_condition"), + ("_notify_async_condition", "_post_async_init_notify_async_condition"), + ] def close(self) -> None: with self._sync_mutex: self._closing = True - for fut in self._pending: - fut.cancel() + if isinstance(self._loop, PreInitDummyLoop): + self._loop.cleanup() + else: + for fut in self._pending: + fut.cancel() self._finished.set() # unblocks all async_q.join() self._all_tasks_done.notify_all() # unblocks all sync_q.join() @@ -146,9 +331,13 @@ async def wait_closed(self) -> None: # _notify_async_not_empty, _notify_async_not_full # methods. await asyncio.sleep(0) - if not self._pending: - return - await asyncio.wait(self._pending) + + if isinstance(self._loop, PreInitDummyLoop): + self._loop.wait() + else: + if not self._pending: + return + await asyncio.wait(self._pending) @property def closed(self) -> bool: @@ -163,7 +352,7 @@ def sync_q(self) -> "_SyncQueueProxy[T]": return self._sync_queue @property - def async_q(self) -> "_AsyncQueueProxy[T]": + def async_q(self) -> Union[PreInitDummyAsyncQueue, "_AsyncQueueProxy[T]"]: return self._async_queue # Override these methods to implement other queue organizations @@ -189,26 +378,37 @@ def _put_internal(self, item: T) -> None: self._unfinished_tasks += 1 self._finished.clear() - def _notify_sync_not_empty(self) -> None: + def _post_async_init_notify_sync_condition(self, condition: asyncio.Condition) -> None: + """ Replacement for '_notify_sync_condition', after initializing the async parts """ def f() -> None: with self._sync_mutex: - self._sync_not_empty.notify() + condition.notify() + + fut = asyncio.ensure_future(self._loop.run_in_executor(None, f), loop=self._loop) + fut.add_done_callback(self._pending.discard) + self._pending.add(fut) - self._loop.run_in_executor(None, f) + def _notify_sync_condition(self, condition: asyncio.Condition) -> None: + """A single interface for notifying sync conditions""" + loop = self._loop # type: PreInitDummyLoop - def _notify_sync_not_full(self) -> None: def f() -> None: with self._sync_mutex: - self._sync_not_full.notify() + condition.notify() - fut = asyncio.ensure_future(self._loop.run_in_executor(None, f)) - fut.add_done_callback(self._pending.discard) - self._pending.add(fut) + loop.run_in_executor(f) - def _notify_async_not_empty(self, *, threadsafe: bool) -> None: + def _notify_sync_not_empty(self) -> None: + self._notify_sync_condition(self._sync_not_empty) + + def _notify_sync_not_full(self) -> None: + self._notify_sync_condition(self._sync_not_full) + + def _post_async_init_notify_async_condition(self, condition: asyncio.Condition, threadsafe: bool): + """ Replacement for '_notify_async_condition', after initializing the async parts """ async def f() -> None: async with self._async_mutex: - self._async_not_empty.notify() + condition.notify() def task_maker() -> None: task = self._loop.create_task(f()) @@ -220,20 +420,17 @@ def task_maker() -> None: else: self._call_soon(task_maker) - def _notify_async_not_full(self, *, threadsafe: bool) -> None: - async def f() -> None: - async with self._async_mutex: - self._async_not_full.notify() + def _notify_async_condition(self, condition: asyncio.Condition, threadsafe: bool): + """A single interface for notifying async conditions - def task_maker() -> None: - task = self._loop.create_task(f()) - task.add_done_callback(self._pending.discard) - self._pending.add(task) + Useless until async parts are not initialized""" + ... - if threadsafe: - self._call_soon_threadsafe(task_maker) - else: - self._call_soon(task_maker) + def _notify_async_not_empty(self, *, threadsafe: bool) -> None: + self._notify_async_condition(self._async_not_empty, threadsafe) + + def _notify_async_not_full(self, *, threadsafe: bool) -> None: + self._notify_async_condition(self._async_not_full, threadsafe) def _check_closing(self) -> None: if self._closing: @@ -271,6 +468,10 @@ def task_done(self) -> None: Raises a ValueError if called more times than there were items placed in the queue. """ + def f(): + with self._parent._all_tasks_done: + self._parent._finished.set() + self._parent._check_closing() with self._parent._all_tasks_done: unfinished = self._parent._unfinished_tasks - 1 @@ -278,7 +479,7 @@ def task_done(self) -> None: if unfinished < 0: raise ValueError("task_done() called too many times") self._parent._all_tasks_done.notify_all() - self._parent._loop.call_soon_threadsafe(self._parent._finished.set) + self._parent._loop.call_soon_threadsafe(f) self._parent._unfinished_tasks = unfinished def join(self) -> None: diff --git a/tests/test_post_init.py b/tests/test_post_init.py new file mode 100644 index 0000000..7288b20 --- /dev/null +++ b/tests/test_post_init.py @@ -0,0 +1,296 @@ +import threading +import asyncio +import janus +import pytest + +from concurrent.futures import ThreadPoolExecutor +from typing import Any + + +class TestOnlySync: + tpe = ThreadPoolExecutor() + + def test_only_sync_init(self): + queue: janus.Queue[int] = janus.Queue(init_async_part=False) + assert not queue.already_initialized + + def test_only_sync_work(self): + queue: janus.Queue[int] = janus.Queue(1, init_async_part=False) + queue.sync_q.put(1) + assert queue.sync_q.get() == 1 + queue.sync_q.task_done() + + def test_only_sync_get_two_threads_put(self): + queue: janus.Queue[int] = janus.Queue(2, init_async_part=False) + queue.sync_q.put(1) + + def put_some_n_times(n, sync_q: janus.SyncQueue): + for i in range(n): + sync_q.put(i) + + a_n = 5 + b_n = 7 + + a_f = self.tpe.submit(put_some_n_times, a_n, queue.sync_q) + b_f = self.tpe.submit(put_some_n_times, b_n, queue.sync_q) + + actual_n = 0 + while a_n + b_n > actual_n: + queue.sync_q.get(timeout=3) + queue.sync_q.task_done() + actual_n += 1 + + a_f.result() + b_f.result() + + assert a_n + b_n == actual_n + + def test_sync_attempt_to_full_init(self): + with pytest.raises(RuntimeError): + janus.Queue(init_async_part=True) + + def test_sync_attempt_to_post_init_0(self): + queue: janus.Queue[Any] = janus.Queue(init_async_part=False) + with pytest.raises(RuntimeError): + queue.trigger_async_initialization() + + def test_sync_attempt_to_post_init_1(self): + queue: janus.Queue[Any] = janus.Queue(init_async_part=False) + with pytest.raises(RuntimeError): + queue.async_q.qsize() + + +class TestSyncThenPostInitAsync: + tpe = ThreadPoolExecutor() + + def test_sync_then_async_0(self): + queue: janus.Queue[Any] = janus.Queue(init_async_part=False) + + async def init_async(queue_: janus.Queue[Any]): + queue_.trigger_async_initialization() + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async(queue)) + + def test_sync_then_async_1(self): + queue: janus.Queue[Any] = janus.Queue(init_async_part=False) + + async def init_async(async_q: janus.AsyncQueue[Any]): + async_q.qsize() + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async(queue.async_q)) + + def test_double_init(self): + queue: janus.Queue[Any] = janus.Queue(init_async_part=False) + + async def init_async(): + queue.trigger_async_initialization() + queue.trigger_async_initialization() + queue.async_q.empty() + assert isinstance(queue.async_q, janus._AsyncQueueProxy) + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async()) + + def test_full_init_with_wrong_coinitializers(self): + class Queue(janus.Queue[Any]): + @property + def _also_initialize_when_triggered(self): + return [None] + + queue: Queue = Queue(init_async_part=False) + + async def init_async(): + with pytest.raises(ValueError): + queue.trigger_async_initialization() + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async()) + + def test_sync_put_then_async_get(self): + _data = [i for i in range(10)] + queue: janus.Queue[Any] = janus.Queue(maxsize=len(_data), init_async_part=False) + + for i in _data: + queue.sync_q.put(i) + assert queue.sync_q.qsize() == i + 1 + + async def init_async(async_q: janus.AsyncQueue[Any]): + it = iter(_data) + + while not async_q.empty(): + i = await async_q.get() + async_q.task_done() + assert i == next(it) + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async(queue.async_q)) + queue.close() + loop.run_until_complete(queue.wait_closed()) + + def test_producer_threads_wait_until_init(self): + num_of_threads = 3 + items = 6 + + queue: janus.Queue[Any] = janus.Queue(maxsize=num_of_threads * items, init_async_part=False) + barrier = threading.Barrier(num_of_threads + 1) + + def put_something_after_init(thr_num): + queue.full_init.wait() + + for i in range(1, items + 1): + queue.sync_q.put((thr_num, i * thr_num)) + + barrier.wait() + + for thr_num in range(1, num_of_threads + 1): + self.tpe.submit(put_something_after_init, thr_num) + + async def init_async(async_q: janus.AsyncQueue[Any]): + while barrier.parties - 1 > barrier.n_waiting or not async_q.empty(): + thread_num, num = await async_q.get() + async_q.task_done() + assert num / (num / thread_num) == thread_num + + barrier.wait() + assert async_q.empty() + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async(queue.async_q)) + queue.close() + loop.run_until_complete(queue.wait_closed()) + + def test_consumer_threads_wait_until_init(self): + num_of_threads = 5 + items = 6 + + total_items = num_of_threads * items + + queue: janus.Queue[Any] = janus.Queue(maxsize=total_items, init_async_part=False) + barrier = threading.Barrier(num_of_threads + 1) + start = threading.Event() + lock = threading.Lock() + last_exception = [] + + def get_something_after_init(): + queue.full_init.wait() + start.wait() + + while not queue.sync_q.empty(): + with lock: + if queue.sync_q.empty(): + break + try: + queue.sync_q.get(block=False) + queue.sync_q.task_done() + except Exception as E: + last_exception.append(E) + break + + barrier.wait() + + for thr_num in range(num_of_threads): + self.tpe.submit(get_something_after_init) + + async def init_async(async_q: janus.AsyncQueue[Any]): + for i in range(total_items): + await async_q.put(i) + + start.set() + barrier.wait() + + if last_exception: + raise last_exception.pop() + + assert async_q.empty() + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async(queue.async_q)) + queue.close() + loop.run_until_complete(queue.wait_closed()) + + def test_consumer_concurrent_threads_wait_until_init(self): + num_of_threads = 6 + items = 7 + + total_items = num_of_threads * items + + queue: janus.Queue[Any] = janus.Queue(maxsize=total_items, init_async_part=False) + barrier = threading.Barrier(num_of_threads + 1) + start = threading.Event() + lock = threading.Lock() + + def do_wit_lock(func): + def wrapper(self): + with lock: + return func(self) + return wrapper + + counter = type( + "Counter", (object,), + { + "c": 0, + "get": do_wit_lock(lambda self: self.c), + "increase": do_wit_lock(lambda self: setattr(self, "c", self.c + 1)) + } + )() + + def get_something_after_init(): + queue.full_init.wait() + start.wait() + + while counter.get() < total_items: + try: + queue.sync_q.get(block=False) + queue.sync_q.task_done() + except janus.SyncQueueEmpty: + ... + + counter.increase() + + barrier.wait() + + for thr_num in range(num_of_threads): + self.tpe.submit(get_something_after_init) + + async def init_async(async_q: janus.AsyncQueue[Any]): + for i in range(total_items): + await async_q.put(i) + + start.set() + barrier.wait() + + loop = asyncio.get_event_loop() + loop.run_until_complete(init_async(queue.async_q)) + queue.close() + loop.run_until_complete(queue.wait_closed()) + + def test_async_producers_sync_consumer(self): + num_of_producers = 6 + items = 7 + + total_items = num_of_producers * items + + queue: janus.Queue[Any] = janus.Queue(maxsize=total_items, init_async_part=False) + + async def producer(async_q: janus.AsyncQueue, prod_num: int): + for i in range(items): + await async_q.put((prod_num, i)) + + fut = asyncio.gather( + *(producer(queue.async_q, cor_num) for cor_num in range(num_of_producers)) + ) + + loop = asyncio.get_event_loop() + loop.run_until_complete(fut) + + actual_count = 0 + while not queue.sync_q.empty(): + queue.sync_q.get() + queue.sync_q.task_done() + actual_count += 1 + + assert actual_count == total_items + + queue.close()