diff --git a/CHANGES/7180.bugfix b/CHANGES/7180.bugfix new file mode 100644 index 00000000000..66980638868 --- /dev/null +++ b/CHANGES/7180.bugfix @@ -0,0 +1 @@ +``ConnectionResetError`` will always be raised when ``StreamWriter.write`` is called after ``connection_lost`` has been called on the ``BaseProtocol`` diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 8189835e211..4c9f0a752e3 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._connection_lost = False self._reading_paused = False self.transport: Optional[asyncio.Transport] = None + @property + def connected(self) -> bool: + """Return True if the connection is open.""" + return self.transport is not None + def pause_writing(self) -> None: assert not self._paused self._paused = True @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = tr def connection_lost(self, exc: Optional[BaseException]) -> None: - self._connection_lost = True # Wake up the writer if currently paused. self.transport = None if not self._paused: @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: waiter.set_exception(exc) async def _drain_helper(self) -> None: - if self._connection_lost: + if not self.connected: raise ConnectionResetError("Connection lost") if not self._paused: return diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index db3d6a04897..73f0f96f0ae 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -35,7 +35,6 @@ def __init__( on_headers_sent: _T_OnHeadersSent = None, ) -> None: self._protocol = protocol - self._transport = protocol.transport self.loop = loop self.length = None @@ -52,7 +51,7 @@ def __init__( @property def transport(self) -> Optional[asyncio.Transport]: - return self._transport + return self._protocol.transport @property def protocol(self) -> BaseProtocol: @@ -71,10 +70,10 @@ def _write(self, chunk: bytes) -> None: size = len(chunk) self.buffer_size += size self.output_size += size - - if self._transport is None or self._transport.is_closing(): + transport = self.transport + if not self._protocol.connected or transport is None or transport.is_closing(): raise ConnectionResetError("Cannot write to closing transport") - self._transport.write(chunk) + transport.write(chunk) async def write( self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None: await self.drain() self._eof = True - self._transport = None async def drain(self) -> None: """Flush the write buffer. diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index f3b966bff54..a16b1f10cb1 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -45,10 +45,10 @@ async def test_connection_lost_not_paused() -> None: pr = BaseProtocol(loop=loop) tr = mock.Mock() pr.connection_made(tr) - assert not pr._connection_lost + assert pr.connected pr.connection_lost(None) assert pr.transport is None - assert pr._connection_lost + assert not pr.connected async def test_connection_lost_paused_without_waiter() -> None: @@ -56,11 +56,11 @@ async def test_connection_lost_paused_without_waiter() -> None: pr = BaseProtocol(loop=loop) tr = mock.Mock() pr.connection_made(tr) - assert not pr._connection_lost + assert pr.connected pr.pause_writing() pr.connection_lost(None) assert pr.transport is None - assert pr._connection_lost + assert not pr.connected async def test_drain_lost() -> None: diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 85225c77dad..eea2830246a 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -134,3 +134,17 @@ async def test_eof_received(loop) -> None: assert proto._read_timeout_handle is not None proto.eof_received() assert proto._read_timeout_handle is None + + +async def test_connection_lost_sets_transport_to_none(loop, mocker) -> None: + """Ensure that the transport is set to None when the connection is lost. + + This ensures the writer knows that the connection is closed. + """ + proto = ResponseHandler(loop=loop) + proto.connection_made(mocker.Mock()) + assert proto.transport is not None + + proto.connection_lost(OSError()) + + assert proto.transport is None diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 8ebcfc654a5..5649f32f792 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -236,6 +236,21 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None: await msg.write(b"After closing") +async def test_write_to_closed_transport(protocol, transport, loop) -> None: + """Test that writing to a closed transport raises ConnectionResetError. + + The StreamWriter checks to see if protocol.transport is None before + writing to the transport. If it is None, it raises ConnectionResetError. + """ + msg = http.StreamWriter(protocol, loop) + + await msg.write(b"Before transport close") + protocol.transport = None + + with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"): + await msg.write(b"After transport closed") + + async def test_drain(protocol, transport, loop) -> None: msg = http.StreamWriter(protocol, loop) await msg.drain()