diff --git a/.gitignore b/.gitignore index 6414c1e2264..b89a8d9fa52 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ *.tar.gz *~ .DS_Store +.Python .coverage .idea .installed.cfg @@ -25,6 +26,9 @@ develop-eggs dist docs/_build/ eggs +include/ +lib/ +man/ nosetests.xml parts pyvenv diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index be9aac41bbb..08d6dbd89cd 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -186,10 +186,11 @@ def __call__(self, out, buf): # read headers headers, close, compression = self.parse_headers(lines) - if version <= HttpVersion10: - close = True - elif close is None: - close = False + if close is None: # then the headers weren't set in the request + if version <= HttpVersion10: # HTTP 1.0 must asks to not close + close = True + else: # HTTP 1.1 must ask to close. + close = False out.feed_data( RawRequestMessage( @@ -532,13 +533,7 @@ def __init__(self, transport, version, close): self.transport = transport self.version = version self.closing = close - - # disable keep-alive for http/1.0 - if version <= HttpVersion10: - self.keepalive = False - else: - self.keepalive = None - + self.keepalive = None self.chunked = False self.length = None self.headers = CIMultiDict() @@ -555,7 +550,16 @@ def enable_chunked_encoding(self): def keep_alive(self): if self.keepalive is None: - return not self.closing + if self.version < HttpVersion10: + # keep alive not supported at all + return False + if self.version == HttpVersion10: + if self.headers.get(hdrs.CONNECTION) == 'keep-alive': + return True + else: # no headers means we close for Http 1.0 + return False + else: + return not self.closing else: return self.keepalive @@ -591,7 +595,7 @@ def add_header(self, name, value): # connection keep-alive elif 'close' in val: self.keepalive = False - elif 'keep-alive' in val and self.version >= HttpVersion11: + elif 'keep-alive' in val: self.keepalive = True elif name == hdrs.UPGRADE: @@ -836,6 +840,11 @@ class Request(HttpMessage): def __init__(self, transport, method, path, http_version=HttpVersion11, close=False): + # set the default for HTTP 1.0 to be different + # will only be overwritten with keep-alive header + if http_version < HttpVersion11: + close = True + super().__init__(transport, http_version, close) self.method = method diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index f298d414d66..0cf0919ab1e 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -18,7 +18,7 @@ CIMultiDict, MultiDictProxy, MultiDict) -from .protocol import Response as ResponseImpl, HttpVersion11 +from .protocol import Response as ResponseImpl, HttpVersion10 from .streams import EOF_MARKER @@ -95,13 +95,10 @@ def __init__(self, app, message, payload, transport, reader, writer, *, self._post = None self._post_files_cache = None self._headers = CIMultiDictProxy(message.headers) - - if self._version < HttpVersion11: - self._keep_alive = False - elif message.should_close: + if self._version < HttpVersion10: self._keep_alive = False else: - self._keep_alive = True + self._keep_alive = not message.should_close # matchdict, route_name, handler # or information about traversal lookup diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 7e2e4b1de62..ab0b46cb88c 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -5,7 +5,7 @@ import unittest from aiohttp import web, request, FormData from aiohttp.multidict import MultiDict -from aiohttp.protocol import HttpVersion11 +from aiohttp.protocol import HttpVersion, HttpVersion10, HttpVersion11 from aiohttp.streams import EOF_MARKER @@ -462,3 +462,71 @@ def go(): self.assertEqual(403, resp.status) self.loop.run_until_complete(go()) + + def test_http10_keep_alive_default(self): + + @asyncio.coroutine + def handler(request): + yield from request.read() + return web.Response(body=b'OK') + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp = yield from request('GET', url, loop=self.loop, + version=HttpVersion10) + self.assertEqual('close', resp.headers['CONNECTION']) + + self.loop.run_until_complete(go()) + + def test_http09_keep_alive_default(self): + + @asyncio.coroutine + def handler(request): + yield from request.read() + return web.Response(body=b'OK') + + @asyncio.coroutine + def go(): + headers = {'Connection': 'keep-alive'} # should be ignored + _, _, url = yield from self.create_server('GET', '/', handler) + resp = yield from request('GET', url, loop=self.loop, + headers=headers, + version=HttpVersion(0, 9)) + self.assertEqual('close', resp.headers['CONNECTION']) + + self.loop.run_until_complete(go()) + + def test_http10_keep_alive_with_headers_close(self): + + @asyncio.coroutine + def handler(request): + yield from request.read() + return web.Response(body=b'OK') + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + headers = {'Connection': 'close'} + resp = yield from request('GET', url, loop=self.loop, + headers=headers, version=HttpVersion10) + self.assertEqual('close', resp.headers['CONNECTION']) + + self.loop.run_until_complete(go()) + + def test_http10_keep_alive_with_headers(self): + + @asyncio.coroutine + def handler(request): + yield from request.read() + return web.Response(body=b'OK') + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + headers = {'Connection': 'keep-alive'} + resp = yield from request('GET', url, loop=self.loop, + headers=headers, version=HttpVersion10) + self.assertEqual('keep-alive', resp.headers['CONNECTION']) + + self.loop.run_until_complete(go()) diff --git a/tests/test_web_request.py b/tests/test_web_request.py index eddcb3eb78f..257e2124262 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -18,6 +18,8 @@ def tearDown(self): def make_request(self, method, path, headers=CIMultiDict(), *, version=HttpVersion(1, 1), closing=False): + if version < HttpVersion(1, 1): + closing = True self.app = mock.Mock() message = RawRequestMessage(method, path, version, headers, closing, False) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 0c6b1f4843a..522fb458f02 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -4,7 +4,8 @@ from aiohttp import hdrs from aiohttp.multidict import CIMultiDict from aiohttp.web import Request, StreamResponse, Response -from aiohttp.protocol import RawRequestMessage, HttpVersion11 +from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 +from aiohttp.protocol import RawRequestMessage class TestStreamResponse(unittest.TestCase): @@ -17,9 +18,12 @@ def tearDown(self): self.loop.close() def make_request(self, method, path, headers=CIMultiDict()): - self.app = mock.Mock() message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) + return self.request_from_message(message) + + def request_from_message(self, message): + self.app = mock.Mock() self.payload = mock.Mock() self.transport = mock.Mock() self.reader = mock.Mock() @@ -347,6 +351,31 @@ def test___repr__not_started(self): resp = StreamResponse(reason=301) self.assertEqual("", repr(resp)) + def test_keep_alive_http10(self): + message = RawRequestMessage('GET', '/', HttpVersion10, CIMultiDict(), + True, False) + req = self.request_from_message(message) + resp = StreamResponse() + resp.start(req) + self.assertFalse(resp.keep_alive) + + headers = CIMultiDict(Connection='keep-alive') + message = RawRequestMessage('GET', '/', HttpVersion10, headers, + False, False) + req = self.request_from_message(message) + resp = StreamResponse() + resp.start(req) + self.assertEqual(resp.keep_alive, True) + + def test_keep_alive_http09(self): + headers = CIMultiDict(Connection='keep-alive') + message = RawRequestMessage('GET', '/', HttpVersion(0, 9), headers, + False, False) + req = self.request_from_message(message) + resp = StreamResponse() + resp.start(req) + self.assertFalse(resp.keep_alive) + class TestResponse(unittest.TestCase):