diff --git a/Makefile b/Makefile index 64d3911..4d3aa67 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ fmt: cov: flake develop pytest --cov=janus --cov=tests --cov-report=term --cov-report=html - @echo "open file://`pwd`/htmlcov/index.html" + @echo "open file://`pwd`/coverage/index.html" checkrst: python setup.py check --restructuredtext diff --git a/janus/__init__.py b/janus/__init__.py index b53cad6..a070156 100644 --- a/janus/__init__.py +++ b/janus/__init__.py @@ -574,7 +574,7 @@ async def get(self) -> T: parent = self._parent async with parent._async_not_empty: with parent._sync_mutex: - if parent._is_shutdown: + if parent._is_shutdown and not parent._qsize(): raise AsyncQueueShutDown parent._get_loop() # check the event loop while not parent._qsize(): @@ -585,7 +585,7 @@ async def get(self) -> T: finally: parent._sync_mutex.acquire() parent._async_not_empty_waiting -= 1 - if parent._is_shutdown: + if parent._is_shutdown and not parent._qsize(): raise AsyncQueueShutDown item = parent._get() @@ -602,7 +602,7 @@ def get_nowait(self) -> T: """ parent = self._parent with parent._sync_mutex: - if parent._is_shutdown: + if parent._is_shutdown and not parent._qsize(): raise AsyncQueueShutDown if not parent._qsize(): raise AsyncQueueEmpty diff --git a/tests/test_async.py b/tests/test_async.py index 093e05d..f13841c 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -446,6 +446,145 @@ async def put(): await close(_q) +class TestQueueShutdown: + @pytest.mark.asyncio + async def test_shutdown_empty(self): + _q = janus.Queue() + q = _q.async_q + + q.shutdown() + with pytest.raises(janus.AsyncQueueShutDown): + await q.put("data") + with pytest.raises(janus.AsyncQueueShutDown): + await q.get() + with pytest.raises(janus.AsyncQueueShutDown): + q.get_nowait() + + @pytest.mark.asyncio + async def test_shutdown_nonempty(self): + _q = janus.Queue() + q = _q.async_q + + await q.put("data") + q.shutdown() + await q.get() + with pytest.raises(janus.AsyncQueueShutDown): + await q.get() + + @pytest.mark.asyncio + async def test_shutdown_nonempty_get_nowait(self): + _q = janus.Queue() + q = _q.async_q + + await q.put("data") + q.shutdown() + q.get_nowait() + with pytest.raises(janus.AsyncQueueShutDown): + q.get_nowait() + + @pytest.mark.asyncio + async def test_shutdown_immediate(self): + _q = janus.Queue() + q = _q.async_q + + await q.put("data") + q.shutdown(immediate=True) + with pytest.raises(janus.AsyncQueueShutDown): + await q.get() + with pytest.raises(janus.AsyncQueueShutDown): + q.get_nowait() + + @pytest.mark.asyncio + async def test_shutdown_immediate_with_undone_tasks(self): + _q = janus.Queue() + q = _q.async_q + + await q.put(1) + await q.put(2) + # artificial .task_done() without .get() for covering specific codeline + # in .shutdown(True) + q.task_done() + + q.shutdown(True) + await close(_q) + + @pytest.mark.asyncio + async def test_shutdown_putter(self): + _q = janus.Queue(maxsize=1) + q = _q.async_q + + await q.put(1) + + async def putter(): + await q.put(2) + + task = asyncio.create_task(putter()) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + with pytest.raises(janus.AsyncQueueShutDown): + await task + + await close(_q) + + @pytest.mark.asyncio + async def test_shutdown_many_putters(self): + _q = janus.Queue(maxsize=1) + q = _q.async_q + + await q.put(1) + + async def putter(n): + await q.put(n) + + tasks = [] + for i in range(2): + tasks.append(asyncio.create_task(putter(i))) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + for task in tasks: + with pytest.raises(janus.AsyncQueueShutDown): + await task + + await close(_q) + + @pytest.mark.asyncio + async def test_shutdown_getter(self): + _q = janus.Queue() + q = _q.async_q + + async def getter(): + await q.get() + + task = asyncio.create_task(getter()) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + with pytest.raises(janus.AsyncQueueShutDown): + await task + + await close(_q) + + @pytest.mark.asyncio + async def test_shutdown_early_getter(self): + _q = janus.Queue() + q = _q.async_q + + q.shutdown() + + with pytest.raises(janus.AsyncQueueShutDown): + await q.get() + + await close(_q) + + class TestLifoQueue: @pytest.mark.asyncio async def test_order(self): diff --git a/tests/test_mixed.py b/tests/test_mixed.py index a81645d..a0c02da 100644 --- a/tests/test_mixed.py +++ b/tests/test_mixed.py @@ -1,6 +1,5 @@ import asyncio import sys -import threading from concurrent.futures import ThreadPoolExecutor @@ -383,3 +382,19 @@ async def test_get_notifies_async_not_full(self): await asyncio.gather(*tasks) assert q.sync_q.qsize() == 2 await q.aclose() + + @pytest.mark.asyncio + async def test_wait_closed_with_pending_tasks(self): + q = janus.Queue() + + async def getter(): + await q.async_q.get() + + task = asyncio.create_task(getter()) + await asyncio.sleep(0.01) + q.shutdown() + # q._pending is not empty now + await q.wait_closed() + + with pytest.raises(janus.AsyncQueueShutDown): + await task diff --git a/tests/test_sync.py b/tests/test_sync.py index 330e9f2..161ab84 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -434,3 +434,189 @@ def test_sync_only_api(): q = janus.Queue() q.sync_q.put(1) assert q.sync_q.get() == 1 + + +class TestQueueShutdown: + @pytest.mark.asyncio + async def test_shutdown_empty(self): + _q = janus.Queue() + q = _q.sync_q + + q.shutdown() + with pytest.raises(janus.SyncQueueShutDown): + q.put("data") + with pytest.raises(janus.SyncQueueShutDown): + q.get() + with pytest.raises(janus.SyncQueueShutDown): + q.get_nowait() + + @pytest.mark.asyncio + async def test_shutdown_nonempty(self): + _q = janus.Queue() + q = _q.sync_q + + q.put("data") + q.shutdown() + q.get() + with pytest.raises(janus.SyncQueueShutDown): + q.get() + + @pytest.mark.asyncio + async def test_shutdown_nonempty_get_nowait(self): + _q = janus.Queue() + q = _q.sync_q + + q.put("data") + q.shutdown() + q.get_nowait() + with pytest.raises(janus.SyncQueueShutDown): + q.get_nowait() + + @pytest.mark.asyncio + async def test_shutdown_immediate(self): + _q = janus.Queue() + q = _q.sync_q + + q.put("data") + q.shutdown(immediate=True) + with pytest.raises(janus.SyncQueueShutDown): + q.get() + with pytest.raises(janus.SyncQueueShutDown): + q.get_nowait() + + @pytest.mark.asyncio + async def test_shutdown_immediate_with_undone_tasks(self): + _q = janus.Queue() + q = _q.sync_q + + q.put(1) + q.put(2) + # artificial .task_done() without .get() for covering specific codeline + # in .shutdown(True) + q.task_done() + + q.shutdown(True) + + @pytest.mark.asyncio + async def test_shutdown_putter(self): + loop = asyncio.get_running_loop() + _q = janus.Queue(maxsize=1) + q = _q.sync_q + + q.put(1) + + def putter(): + q.put(2) + + fut = loop.run_in_executor(None, putter) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + with pytest.raises(janus.SyncQueueShutDown): + await fut + + await _q.aclose() + + @pytest.mark.asyncio + async def test_shutdown_many_putters(self): + loop = asyncio.get_running_loop() + _q = janus.Queue(maxsize=1) + q = _q.sync_q + + q.put(1) + + def putter(n): + q.put(n) + + futs = [] + for i in range(2): + futs.append(loop.run_in_executor(None, putter, i)) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + for fut in futs: + with pytest.raises(janus.SyncQueueShutDown): + await fut + + await _q.aclose() + + @pytest.mark.asyncio + async def test_shutdown_many_putters_with_timeout(self): + loop = asyncio.get_running_loop() + _q = janus.Queue(maxsize=1) + q = _q.sync_q + + q.put(1) + + def putter(n): + q.put(n, timeout=60) + + futs = [] + for i in range(2): + futs.append(loop.run_in_executor(None, putter, i)) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + for fut in futs: + with pytest.raises(janus.SyncQueueShutDown): + await fut + + await _q.aclose() + + @pytest.mark.asyncio + async def test_shutdown_getter(self): + loop = asyncio.get_running_loop() + _q = janus.Queue() + q = _q.sync_q + + def getter(): + q.get() + + fut = loop.run_in_executor(None, getter) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + with pytest.raises(janus.SyncQueueShutDown): + await fut + + await _q.aclose() + + @pytest.mark.asyncio + async def test_shutdown_getter_with_timeout(self): + loop = asyncio.get_running_loop() + _q = janus.Queue() + q = _q.sync_q + + def getter(): + q.get(timeout=60) + + fut = loop.run_in_executor(None, getter) + # wait for the task start + await asyncio.sleep(0.01) + + q.shutdown() + + with pytest.raises(janus.SyncQueueShutDown): + await fut + + await _q.aclose() + + @pytest.mark.asyncio + async def test_shutdown_early_getter(self): + _q = janus.Queue() + q = _q.sync_q + + q.shutdown() + + with pytest.raises(janus.SyncQueueShutDown): + q.get() + + await _q.aclose()