diff --git a/sanic/exceptions.py b/sanic/exceptions.py index b06c76d157..2c4ab2c02e 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -218,6 +218,11 @@ def __init__(self, message, content_range): } +@add_status_code(417) +class HeaderExpectationFailed(SanicException): + pass + + @add_status_code(403) class Forbidden(SanicException): pass diff --git a/sanic/request.py b/sanic/request.py index dfb3d1ffec..15c2d5c4e5 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -29,7 +29,7 @@ def json_loads(data): DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" - +EXPECT_HEADER = "EXPECT" # HTTP/1.1: https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 # > If the media type remains unknown, the recipient SHOULD treat it diff --git a/sanic/server.py b/sanic/server.py index a2038e3c86..f8a9b20376 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -15,6 +15,7 @@ from multidict import CIMultiDict from sanic.exceptions import ( + HeaderExpectationFailed, InvalidUsage, PayloadTooLarge, RequestTimeout, @@ -22,7 +23,7 @@ ServiceUnavailable, ) from sanic.log import access_logger, logger -from sanic.request import Request, StreamBuffer +from sanic.request import EXPECT_HEADER, Request, StreamBuffer from sanic.response import HTTPResponse @@ -314,6 +315,10 @@ def on_headers_complete(self): if self._keep_alive_timeout_handler: self._keep_alive_timeout_handler.cancel() self._keep_alive_timeout_handler = None + + if self.request.headers.get(EXPECT_HEADER): + self.expect_handler() + if self.is_request_stream: self._is_stream_handler = self.router.is_stream_handler( self.request @@ -324,6 +329,21 @@ def on_headers_complete(self): ) self.execute_request_handler() + def expect_handler(self): + """ + Handler for Expect Header. + """ + expect = self.request.headers.get(EXPECT_HEADER) + if self.request.version == "1.1": + if expect.lower() == "100-continue": + self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") + else: + self.write_error( + HeaderExpectationFailed( + "Unknown Expect: {expect}".format(expect=expect) + ) + ) + def on_body(self, body): if self.is_request_stream and self._is_stream_handler: self._request_stream_task = self.loop.create_task( diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index d845dc8507..70fa621b81 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -1,4 +1,6 @@ +import pytest from sanic.blueprints import Blueprint +from sanic.exceptions import HeaderExpectationFailed from sanic.request import StreamBuffer from sanic.response import stream, text from sanic.views import CompositionView, HTTPMethodView @@ -40,6 +42,38 @@ async def post(self, request): assert response.text == data +@pytest.mark.parametrize("headers, expect_raise_exception", [ +({"EXPECT": "100-continue"}, False), +({"EXPECT": "100-continue-extra"}, True), +]) +def test_request_stream_100_continue(app, headers, expect_raise_exception): + class SimpleView(HTTPMethodView): + + @stream_decorator + async def post(self, request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + app.add_route(SimpleView.as_view(), "/method_view") + + assert app.is_request_stream is True + + if not expect_raise_exception: + request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"}) + assert response.status == 200 + assert response.text == data + else: + with pytest.raises(ValueError) as e: + app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) + assert "Unknown Expect: 100-continue-extra" in str(e) + + def test_request_stream_app(app): """for self.is_request_stream = True and decorators"""