diff --git a/CHANGES/3980.feature b/CHANGES/3980.feature new file mode 100644 index 00000000000..89654c156c4 --- /dev/null +++ b/CHANGES/3980.feature @@ -0,0 +1 @@ +Accept non-GET request for starting websocket handshake on server side. diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 8e81eaaa3f3..4c5e7ca5529 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -26,7 +26,7 @@ from .log import ws_logger from .streams import EofStream, FlowControlDataQueue from .typedefs import JSONDecoder, JSONEncoder -from .web_exceptions import HTTPBadRequest, HTTPException, HTTPMethodNotAllowed +from .web_exceptions import HTTPBadRequest, HTTPException from .web_request import BaseRequest from .web_response import StreamResponse @@ -130,8 +130,6 @@ def _handshake(self, request: BaseRequest) -> Tuple['CIMultiDict[str]', bool, bool]: headers = request.headers - if request.method != hdrs.METH_GET: - raise HTTPMethodNotAllowed(request.method, [hdrs.METH_GET]) if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip(): raise HTTPBadRequest( text=('No WebSocket UPGRADE hdr: {}\n Can ' diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 2d810c85e60..cab47b01066 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -8,7 +8,7 @@ from aiohttp.log import ws_logger from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web import HTTPBadRequest, HTTPMethodNotAllowed, WebSocketResponse +from aiohttp.web import HTTPBadRequest, WebSocketResponse from aiohttp.web_ws import WS_CLOSED_MESSAGE, WebSocketReady @@ -203,12 +203,6 @@ def test_can_prepare_unknown_protocol(make_request) -> None: assert WebSocketReady(True, None) == ws.can_prepare(req) -def test_can_prepare_invalid_method(make_request) -> None: - req = make_request('POST', '/') - ws = WebSocketResponse() - assert WebSocketReady(False, None) == ws.can_prepare(req) - - def test_can_prepare_without_upgrade(make_request) -> None: req = make_request('GET', '/', headers=CIMultiDict({})) @@ -302,11 +296,11 @@ async def test_close_idempotent(make_request) -> None: assert not (await ws.close(code=2, message='message2')) -async def test_prepare_invalid_method(make_request) -> None: +async def test_prepare_post_method_ok(make_request) -> None: req = make_request('POST', '/') ws = WebSocketResponse() - with pytest.raises(HTTPMethodNotAllowed): - await ws.prepare(req) + await ws.prepare(req) + assert ws.prepared async def test_prepare_without_upgrade(make_request) -> None: diff --git a/tests/test_websocket_handshake.py b/tests/test_websocket_handshake.py index 21852b3c290..1ab709e3cf7 100644 --- a/tests/test_websocket_handshake.py +++ b/tests/test_websocket_handshake.py @@ -32,13 +32,6 @@ def gen_ws_headers(protocols='', compress=0, extension_text='', return hdrs, key -async def test_not_get() -> None: - ws = web.WebSocketResponse() - req = make_mocked_request('POST', '/') - with pytest.raises(web.HTTPMethodNotAllowed): - await ws.prepare(req) - - async def test_no_upgrade() -> None: ws = web.WebSocketResponse() req = make_mocked_request('GET', '/')