Skip to content

Commit

Permalink
Fix exceptions in websocket receive_* methods (#9511)
Browse files Browse the repository at this point in the history
Co-authored-by: J. Nick Koston <[email protected]>
  • Loading branch information
ara-25 and bdraco authored Nov 5, 2024
1 parent 37d9fe6 commit 75ae623
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGES/6800.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Modified websocket :meth:`aiohttp.ClientWebSocketResponse.receive_str`, :py:meth:`aiohttp.ClientWebSocketResponse.receive_bytes`, :py:meth:`aiohttp.web.WebSocketResponse.receive_str` & :py:meth:`aiohttp.web.WebSocketResponse.receive_bytes` methods to raise new :py:exc:`aiohttp.WSMessageTypeError` exception, instead of generic :py:exc:`TypeError`, when websocket messages of incorrect types are received -- by :user:`ara-25`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
- Contributors -
----------------
A. Jesse Jiryu Davis
Abdur Rehman Ali
Adam Bannister
Adam Cooper
Adam Horacek
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
TCPConnector,
TooManyRedirects,
UnixConnector,
WSMessageTypeError,
WSServerHandshakeError,
request,
)
Expand Down Expand Up @@ -228,6 +229,7 @@
# workers (imported lazily with __getattr__)
"GunicornUVLoopWebWorker",
"GunicornWebWorker",
"WSMessageTypeError",
)


Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
ServerTimeoutError,
SocketTimeoutError,
TooManyRedirects,
WSMessageTypeError,
WSServerHandshakeError,
)
from .client_reqrep import (
Expand Down Expand Up @@ -152,6 +153,7 @@
"ClientTimeout",
"ClientWSTimeout",
"request",
"WSMessageTypeError",
)


Expand Down
5 changes: 5 additions & 0 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"NonHttpUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlRedirectClientError",
"WSMessageTypeError",
)


Expand Down Expand Up @@ -377,3 +378,7 @@ def __str__(self) -> str:
"[{0.certificate_error.__class__.__name__}: "
"{0.certificate_error.args}]".format(self)
)


class WSMessageTypeError(TypeError):
"""WebSocket message type is not valid."""
10 changes: 7 additions & 3 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import TracebackType
from typing import Any, Final, Optional, Type

from .client_exceptions import ClientError, ServerTimeoutError
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .http import (
Expand Down Expand Up @@ -379,13 +379,17 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str")
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return msg.data

async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
)
return msg.data

async def receive_json(
Expand Down
11 changes: 6 additions & 5 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import hdrs
from ._websocket.writer import DEFAULT_LIMIT
from .abc import AbstractStreamWriter
from .client_exceptions import WSMessageTypeError
from .helpers import calculate_timeout_when, set_exception, set_result
from .http import (
WS_CLOSED_MESSAGE,
Expand Down Expand Up @@ -602,17 +603,17 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not WSMsgType.TEXT".format(
msg.type, msg.data
)
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return msg.data

async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes")
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
)
return msg.data

async def receive_json(
Expand Down
10 changes: 8 additions & 2 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,7 @@ manually.

:return str: peer's message content.

:raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`.
:raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.TEXT`.

.. method:: receive_bytes()
:async:
Expand All @@ -1684,7 +1684,7 @@ manually.

:return bytes: peer's message content.

:raise TypeError: if message is :const:`~aiohttp.WSMsgType.TEXT`.
:raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.BINARY`.

.. method:: receive_json(*, loads=json.loads)
:async:
Expand Down Expand Up @@ -2239,6 +2239,12 @@ Response errors

Derived from :exc:`ClientResponseError`

.. exception:: WSMessageTypeError

Received WebSocket message of unexpected type

Derived from :exc:`TypeError`

Connection errors
^^^^^^^^^^^^^^^^^

Expand Down
4 changes: 2 additions & 2 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ and :ref:`aiohttp-web-signals` handlers::

:return str: peer's message content.

:raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`.
:raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.TEXT`.

.. method:: receive_bytes(*, timeout=None)
:async:
Expand All @@ -1266,7 +1266,7 @@ and :ref:`aiohttp-web-signals` handlers::

:return bytes: peer's message content.

:raise TypeError: if message is :const:`~aiohttp.WSMsgType.TEXT`.
:raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.BINARY`.

.. method:: receive_json(*, loads=json.loads, timeout=None)
:async:
Expand Down
55 changes: 52 additions & 3 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
import pytest

import aiohttp
from aiohttp import ClientConnectionResetError, ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp import (
ClientConnectionResetError,
ServerTimeoutError,
WSMessageTypeError,
WSMsgType,
hdrs,
web,
)
from aiohttp.client_ws import ClientWSTimeout
from aiohttp.http import WSCloseCode
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
Expand Down Expand Up @@ -58,7 +65,28 @@ async def handler(request: web.Request) -> NoReturn:
resp = await client.ws_connect("/")
await resp.send_str("ask")

with pytest.raises(TypeError):
with pytest.raises(WSMessageTypeError):
await resp.receive_bytes()
await resp.close()


async def test_recv_bytes_after_close(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> NoReturn:
ws = web.WebSocketResponse()
await ws.prepare(request)

await ws.close()
assert False

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)
resp = await client.ws_connect("/")

with pytest.raises(
WSMessageTypeError,
match=f"Received message {WSMsgType.CLOSE}:.+ is not WSMsgType.BINARY",
):
await resp.receive_bytes()
await resp.close()

Expand Down Expand Up @@ -103,12 +131,33 @@ async def handler(request: web.Request) -> NoReturn:

await resp.send_bytes(b"ask")

with pytest.raises(TypeError):
with pytest.raises(WSMessageTypeError):
await resp.receive_str()

await resp.close()


async def test_recv_text_after_close(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> NoReturn:
ws = web.WebSocketResponse()
await ws.prepare(request)

await ws.close()
assert False

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)
resp = await client.ws_connect("/")

with pytest.raises(
WSMessageTypeError,
match=f"Received message {WSMsgType.CLOSE}:.+ is not WSMsgType.TEXT",
):
await resp.receive_str()
await resp.close()


async def test_send_recv_json(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
Expand Down
32 changes: 31 additions & 1 deletion tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from multidict import CIMultiDict
from pytest_mock import MockerFixture

from aiohttp import WSMsgType, web
from aiohttp import WSMessageTypeError, WSMsgType, web
from aiohttp.http import WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE
from aiohttp.http_websocket import WSMessageClose
from aiohttp.streams import EofStream
Expand Down Expand Up @@ -263,6 +263,21 @@ async def test_send_str_closed(make_request: _RequestMaker) -> None:
await ws.send_str("string")


async def test_recv_str_closed(make_request: _RequestMaker) -> None:
req = make_request("GET", "/")
ws = web.WebSocketResponse()
await ws.prepare(req)
assert ws._reader is not None
ws._reader.feed_data(WS_CLOSED_MESSAGE)
await ws.close()

with pytest.raises(
WSMessageTypeError,
match=f"Received message {WSMsgType.CLOSED}:.+ is not WSMsgType.TEXT",
):
await ws.receive_str()


async def test_send_bytes_closed(make_request: _RequestMaker) -> None:
req = make_request("GET", "/")
ws = web.WebSocketResponse()
Expand All @@ -275,6 +290,21 @@ async def test_send_bytes_closed(make_request: _RequestMaker) -> None:
await ws.send_bytes(b"bytes")


async def test_recv_bytes_closed(make_request: _RequestMaker) -> None:
req = make_request("GET", "/")
ws = web.WebSocketResponse()
await ws.prepare(req)
assert ws._reader is not None
ws._reader.feed_data(WS_CLOSED_MESSAGE)
await ws.close()

with pytest.raises(
WSMessageTypeError,
match=f"Received message {WSMsgType.CLOSED}:.+ is not WSMsgType.BINARY",
):
await ws.receive_bytes()


async def test_send_json_closed(make_request: _RequestMaker) -> None:
req = make_request("GET", "/")
ws = web.WebSocketResponse()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ async def handler(request: web.Request) -> NoReturn:
assert ws.can_prepare(request)

await ws.prepare(request)
await ws.send_bytes("answer") # type: ignore[arg-type]
await ws.send_str("answer")
assert False

app = web.Application()
Expand Down

0 comments on commit 75ae623

Please sign in to comment.