From 0f8f565175fad3dcf3fb83ee62042e582d32598b Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Tue, 16 May 2023 17:55:20 +0100 Subject: [PATCH] Allow timeout to work when reading with nowait (#5854) (#7292) (Note this depends on and extends #5853) When reading in a loop while the buffer is being constantly filled, the timeout does not work as there are no calls to `_wait()` where the timer is used. I don't know if this edge case is enough to be worried about, but have put together an initial attempt at fixing it. I'm not sure if this is really the right solution, but can atleast be used as as a discussion on ways to improve this. This can't be backported as this changes the public API (one of the functions is now async). Related #5851. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 80e2bde149e12754e8caa9d5380f960e16f7b9e3) ## What do these changes do? ## Are there changes in behavior for the user? ## Related issue number ## Checklist - [ ] I think the code is well written - [ ] Unit tests for the changes exist - [ ] Documentation reflects the changes - [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is <Name> <Surname>. * Please keep alphabetical order, the file is sorted by names. - [ ] Add a new news fragment into the `CHANGES` folder * name it `.` for example (588.bugfix) * if you don't have an `issue_id` change it to the pr id after creating the pr * ensure type is one of the following: * `.feature`: Signifying a new feature. * `.bugfix`: Signifying a bug fix. * `.doc`: Signifying a documentation improvement. * `.removal`: Signifying a deprecation or removal of public API. * `.misc`: A ticket has been closed, but it is not of interest to users. * Make sure to use full sentences with correct case and punctuation, for example: "Fix issue with non-ascii contents in doctest text files." --- CHANGES/5854.bugfix | 1 + aiohttp/helpers.py | 8 +++++++- aiohttp/streams.py | 12 +++++------- tests/test_client_functional.py | 24 ++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 CHANGES/5854.bugfix diff --git a/CHANGES/5854.bugfix b/CHANGES/5854.bugfix new file mode 100644 index 00000000000..b7de2f4d232 --- /dev/null +++ b/CHANGES/5854.bugfix @@ -0,0 +1 @@ +Fixed client timeout not working when incoming data is always available without waiting -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 4b7ce168e3b..47caeebcc93 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -682,7 +682,8 @@ def __call__(self) -> None: class BaseTimerContext(ContextManager["BaseTimerContext"]): - pass + def assert_timeout(self) -> None: + """Raise TimeoutError if timeout has been exceeded.""" class TimerNoop(BaseTimerContext): @@ -706,6 +707,11 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._tasks: List[asyncio.Task[Any]] = [] self._cancelled = False + def assert_timeout(self) -> None: + """Raise TimeoutError if timer has already been cancelled.""" + if self._cancelled: + raise asyncio.TimeoutError from None + def __enter__(self) -> BaseTimerContext: task = current_task(loop=self._loop) diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 32ec7f148f6..ea4bcd3c5a4 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -4,7 +4,7 @@ from typing import Awaitable, Callable, Deque, Generic, List, Optional, Tuple, TypeVar from .base_protocol import BaseProtocol -from .helpers import BaseTimerContext, set_exception, set_result +from .helpers import BaseTimerContext, TimerNoop, set_exception, set_result from .log import internal_logger from .typedefs import Final @@ -116,7 +116,7 @@ def __init__( self._waiter: Optional[asyncio.Future[None]] = None self._eof_waiter: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None - self._timer = timer + self._timer = TimerNoop() if timer is None else timer self._eof_callbacks: List[Callable[[], None]] = [] def __repr__(self) -> str: @@ -291,10 +291,7 @@ async def _wait(self, func_name: str) -> None: waiter = self._waiter = self._loop.create_future() try: - if self._timer: - with self._timer: - await waiter - else: + with self._timer: await waiter finally: self._waiter = None @@ -485,8 +482,9 @@ def _read_nowait_chunk(self, n: int) -> bytes: def _read_nowait(self, n: int) -> bytes: """Read not more than n bytes, or whole buffer if n == -1""" - chunks = [] + self._timer.assert_timeout() + chunks = [] while self._buffer: chunk = self._read_nowait_chunk(n) chunks.append(chunk) diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 93f0dc06a14..81a80eca677 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3034,6 +3034,30 @@ async def handler(request): await resp.read() +async def test_timeout_with_full_buffer(aiohttp_client) -> None: + async def handler(request): + """Server response that never ends and always has more data available.""" + resp = web.StreamResponse() + await resp.prepare(request) + while True: + await resp.write(b"1" * 1000) + await asyncio.sleep(0.01) + + async def request(client): + timeout = aiohttp.ClientTimeout(total=0.5) + async with await client.get("/", timeout=timeout) as resp: + with pytest.raises(asyncio.TimeoutError): + async for data in resp.content.iter_chunked(1): + await asyncio.sleep(0.01) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + # wait_for() used just to ensure that a failing test doesn't hang. + await asyncio.wait_for(request(client), 1) + + async def test_read_bufsize_session_default(aiohttp_client) -> None: async def handler(request): return web.Response(body=b"1234567")