ServerHttpProtocol refactoring #1060

Merged: 11 commits, Aug 19, 2016
aiohttp/errors.py (1 addition, 5 deletions)

@@ -159,11 +159,7 @@ def __init__(self, line=''):
         self.line = line
 
 
-class ParserError(Exception):
-    """Base parser error."""
-
-
-class LineLimitExceededParserError(ParserError):
+class LineLimitExceededParserError(HttpBadRequest):
    """Line is too long."""

    def __init__(self, msg, limit):
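Rebasing `LineLimitExceededParserError` onto `HttpBadRequest` moves an over-long request line into the normal HTTP error path instead of surfacing it as a bare `Exception`. A minimal sketch of what the new hierarchy buys, assuming aiohttp's error classes of this era (where `HttpBadRequest` derives from `HttpProcessingError` and carries a `code` class attribute):

```python
from aiohttp import errors

# After this change the parser error is a regular HTTP processing error,
# so generic handlers can map it to a response status without special-casing.
assert issubclass(errors.LineLimitExceededParserError, errors.HttpBadRequest)
assert issubclass(errors.LineLimitExceededParserError, errors.HttpProcessingError)

# 400, inherited from HttpBadRequest (assumption: the `code` attribute
# that server.py reads as exc.code below).
print(errors.LineLimitExceededParserError.code)
```

This is also why the dedicated `except errors.LineLimitExceededParserError` branch disappears from `server.py` below: the generic `HttpProcessingError` branch now covers it.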
aiohttp/server.py (75 additions, 139 deletions)

@@ -5,12 +5,12 @@
 import socket
 import traceback
 import warnings
+from contextlib import suppress
 from html import escape as html_escape
-from math import ceil
 
 import aiohttp
 from aiohttp import errors, hdrs, helpers, streams
-from aiohttp.helpers import _get_kwarg, ensure_future
+from aiohttp.helpers import Timeout, _get_kwarg, ensure_future
 from aiohttp.log import access_logger, server_logger
 
 __all__ = ('ServerHttpProtocol',)
@@ -53,15 +53,11 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
 
     :param keepalive_timeout: number of seconds before closing
                               keep-alive connection
-    :type keepalive: int or None
+    :type keepalive_timeout: int or None
 
     :param bool tcp_keepalive: TCP keep-alive is on, default is on
 
-    :param int timeout: slow request timeout
-
-    :param allowed_methods: (optional) List of allowed request methods.
-                            Set to empty list to allow all methods.
-    :type allowed_methods: tuple
+    :param int slow_request_timeout: slow request timeout
 
     :param bool debug: enable debug mode

@@ -85,9 +81,7 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
     _request_count = 0
     _request_handler = None
     _reading_request = False
-    _keep_alive = False  # keep transport open
-    _keep_alive_handle = None  # keep alive timer handle
-    _slow_request_timeout_handle = None  # slow request timer handle
+    _keepalive = False  # keep transport open
 
     def __init__(self, *, loop=None,
                  keepalive_timeout=75,  # NGINX default value is 75 secs
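For orientation, a hypothetical end-to-end wiring of the protocol with the parameters documented above. The handler subclass and the `aiohttp.Response` calls follow the low-level server example from aiohttp's documentation of this era, so treat the exact API and values as assumptions rather than part of this PR:

```python
import asyncio
import aiohttp
import aiohttp.server

class EchoRequestHandler(aiohttp.server.ServerHttpProtocol):

    @asyncio.coroutine
    def handle_request(self, message, payload):
        # Docs-era low-level response API (assumption, not changed by this PR).
        response = aiohttp.Response(
            self.writer, 200, http_version=message.version)
        response.add_header('Content-Type', 'text/plain')
        response.add_header('Content-Length', '2')
        response.send_headers()
        response.write(b'ok')
        yield from response.write_eof()

loop = asyncio.get_event_loop()
server = loop.run_until_complete(loop.create_server(
    lambda: EchoRequestHandler(loop=loop,
                               keepalive_timeout=75,      # NGINX default
                               slow_request_timeout=15,   # documented above,
                                                          # formerly `timeout`
                               tcp_keepalive=True,
                               debug=True),
    '127.0.0.1', 8080))
```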
@@ -138,6 +132,7 @@ def __init__(self, *, loop=None,
                                              access_log_format)
         else:
             self.access_logger = None
+        self._closing = False
 
     @property
     def keep_alive_timeout(self):
@@ -150,57 +145,38 @@ def keep_alive_timeout(self):
     def keepalive_timeout(self):
         return self._keepalive_timeout
 
-    def closing(self, timeout=15.0):
+    @asyncio.coroutine
+    def shutdown(self, timeout=15.0):
         """Worker process is about to exit, we need cleanup everything and
         stop accepting requests. It is especially important for keep-alive
         connections."""
-        self._keep_alive = False
-        self._tcp_keep_alive = False
-        self._keepalive_timeout = None
-
-        if (not self._reading_request and self.transport is not None):
-            if self._request_handler:
-                self._request_handler.cancel()
-                self._request_handler = None
-
-            self.transport.close()
-            self.transport = None
-        elif self.transport is not None and timeout:
-            if self._slow_request_timeout_handle is not None:
-                self._slow_request_timeout_handle.cancel()
-
-            # use slow request timeout for closing
-            # connection_lost cleans timeout handler
-            now = self._loop.time()
-            self._slow_request_timeout_handle = self._loop.call_at(
-                ceil(now+timeout), self.cancel_slow_request)
+        if self._request_handler is None:
+            return
+        self._closing = True
+
+        if timeout:
+            canceller = self._loop.call_later(timeout,
+                                              self._request_handler.cancel)
+            with suppress(asyncio.CancelledError):
+                yield from self._request_handler
+            canceller.cancel()
+        else:
+            self._request_handler.cancel()
 
     def connection_made(self, transport):
         super().connection_made(transport)
 
         self._request_handler = ensure_future(self.start(), loop=self._loop)
 
-        # start slow request timer
-        if self._slow_request_timeout:
-            now = self._loop.time()
-            self._slow_request_timeout_handle = self._loop.call_at(
-                ceil(now+self._slow_request_timeout), self.cancel_slow_request)
-
         if self._tcp_keepalive:
             tcp_keepalive(self, transport)
 
     def connection_lost(self, exc):
         super().connection_lost(exc)
 
+        self._closing = True
         if self._request_handler is not None:
             self._request_handler.cancel()
             self._request_handler = None
-        if self._keep_alive_handle is not None:
-            self._keep_alive_handle.cancel()
-            self._keep_alive_handle = None
-        if self._slow_request_timeout_handle is not None:
-            self._slow_request_timeout_handle.cancel()
-            self._slow_request_timeout_handle = None
 
     def data_received(self, data):
         super().data_received(data)
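The coroutine `shutdown()` replaces `closing()` and its timer-handle bookkeeping with one await-then-cancel pattern: wait for the in-flight handler, but arm a `call_later` that cancels it if the timeout expires first. A self-contained sketch of that pattern outside the protocol (the sleeping task is a hypothetical stand-in for `self._request_handler`):

```python
import asyncio
from contextlib import suppress

@asyncio.coroutine
def graceful_shutdown(task, timeout, loop):
    # Give the in-flight handler `timeout` seconds, then cancel it.
    canceller = loop.call_later(timeout, task.cancel)
    with suppress(asyncio.CancelledError):
        yield from task  # returns immediately if the task finishes first
    canceller.cancel()  # disarm the timer once the task is done

loop = asyncio.get_event_loop()
slow_handler = asyncio.ensure_future(asyncio.sleep(10), loop=loop)
loop.run_until_complete(graceful_shutdown(slow_handler, 0.1, loop))
print(slow_handler.cancelled())  # True: cancelled 0.1 s into a 10 s sleep
```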
@@ -209,17 +185,12 @@ def data_received(self, data):
         if not self._reading_request:
             self._reading_request = True
 
-            # stop keep-alive timer
-            if self._keep_alive_handle is not None:
-                self._keep_alive_handle.cancel()
-                self._keep_alive_handle = None
-
     def keep_alive(self, val):
         """Set keep-alive connection mode.
 
         :param bool val: new state.
         """
-        self._keep_alive = val
+        self._keepalive = val
 
     def log_access(self, message, environ, response, time):
         if self.access_logger:
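The setter is the hook a `handle_request` implementation (like the sketch above) uses to opt a connection into keep-alive; this PR only renames the backing flag from `_keep_alive` to `_keepalive`. A trivial, illustration-only check (constructing the protocol bare and reading the private attribute is just for demonstration):

```python
import asyncio
import aiohttp.server

loop = asyncio.get_event_loop()
proto = aiohttp.server.ServerHttpProtocol(loop=loop)

# start() consults this flag after each request to decide between
# waiting for another request under Timeout(...) and closing.
proto.keep_alive(True)
assert proto._keepalive is True
proto.keep_alive(False)
assert proto._keepalive is False
```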
@@ -233,16 +204,6 @@ def log_debug(self, *args, **kw):
     def log_exception(self, *args, **kw):
         self.logger.exception(*args, **kw)
 
-    def cancel_slow_request(self):
-        if self._request_handler is not None:
-            self._request_handler.cancel()
-            self._request_handler = None
-
-        if self.transport is not None:
-            self.transport.close()
-
-        self.log_debug('Close slow request.')
-
     @asyncio.coroutine
     def start(self):
         """Start processing of incoming requests.
@@ -255,44 +216,35 @@ def start(self):
         """
         reader = self.reader
 
-        while True:
-            message = None
-            self._keep_alive = False
-            self._request_count += 1
-            self._reading_request = False
-
-            payload = None
-            try:
-                # read HTTP request method
-                prefix = reader.set_parser(self._request_prefix)
-                yield from prefix.read()
-
-                # start reading request
-                self._reading_request = True
-
-                # start slow request timer
-                if (self._slow_request_timeout and
-                        self._slow_request_timeout_handle is None):
-                    now = self._loop.time()
-                    self._slow_request_timeout_handle = self._loop.call_at(
-                        ceil(now+self._slow_request_timeout),
-                        self.cancel_slow_request)
-
-                # read request headers
-                httpstream = reader.set_parser(self._request_parser)
-                message = yield from httpstream.read()
-
-                # cancel slow request timer
-                if self._slow_request_timeout_handle is not None:
-                    self._slow_request_timeout_handle.cancel()
-                    self._slow_request_timeout_handle = None
+        try:
+            while not self._closing:
+                message = None
+                self._keepalive = False
+                self._request_count += 1
+                self._reading_request = False
+
+                payload = None
+                with Timeout(max(self._slow_request_timeout,
+                                 self._keepalive_timeout),
+                             loop=self._loop):
+                    # read HTTP request method
+                    prefix = reader.set_parser(self._request_prefix)
+                    yield from prefix.read()
+
+                    # start reading request
+                    self._reading_request = True
+
+                    # read request headers
+                    httpstream = reader.set_parser(self._request_parser)
+                    message = yield from httpstream.read()
 
                 # request may not have payload
                 try:
                     content_length = int(
                         message.headers.get(hdrs.CONTENT_LENGTH, 0))
                 except ValueError:
-                    content_length = 0
+                    raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
 
                 if (content_length > 0 or
                         message.method == 'CONNECT' or
@@ -308,55 +260,39 @@
 
                 yield from self.handle_request(message, payload)
 
-            except asyncio.CancelledError:
-                return
-            except errors.ClientDisconnectedError:
-                self.log_debug(
-                    'Ignored premature client disconnection #1.')
-                return
-            except errors.HttpProcessingError as exc:
-                if self.transport is not None:
-                    yield from self.handle_error(exc.code, message,
-                                                 None, exc, exc.headers,
-                                                 exc.message)
-            except errors.LineLimitExceededParserError as exc:
-                yield from self.handle_error(400, message, None, exc)
-            except Exception as exc:
-                yield from self.handle_error(500, message, None, exc)
-            finally:
-                if self.transport is None:
-                    self.log_debug(
-                        'Ignored premature client disconnection #2.')
-                    return
-
                 if payload and not payload.is_eof():
                     self.log_debug('Uncompleted request.')
-                    self._request_handler = None
-                    self.transport.close()
-                    return
+                    self._closing = True
                 else:
                     reader.unset_parser()
 
-                if self._request_handler:
-                    if self._keep_alive and self._keepalive_timeout:
-                        self.log_debug(
-                            'Start keep-alive timer for %s sec.',
-                            self._keepalive_timeout)
-                        now = self._loop.time()
-                        self._keep_alive_handle = self._loop.call_at(
-                            ceil(now+self._keepalive_timeout),
-                            self.transport.close)
-                    elif self._keep_alive:
-                        # do nothing, rely on kernel or upstream server
-                        pass
-                    else:
-                        self.log_debug('Close client connection.')
-                        self._request_handler = None
-                        self.transport.close()
-                        return
-                else:
-                    # connection is closed
-                    return
+                if not self._keepalive or not self._keepalive_timeout:
+                    self._closing = True
+
+        except asyncio.CancelledError:
+            self.log_debug(
+                'Request handler cancelled.')
+            return
+        except asyncio.TimeoutError:
+            self.log_debug(
+                'Request handler timed out.')
+            return
+        except errors.ClientDisconnectedError:
+            self.log_debug(
+                'Ignored premature client disconnection #1.')
+            return
+        except errors.HttpProcessingError as exc:
+            yield from self.handle_error(exc.code, message,
+                                         None, exc, exc.headers,
+                                         exc.message)
+        except Exception as exc:
+            yield from self.handle_error(500, message, None, exc)
+        finally:
+            self._request_handler = None
+            if self.transport is None:
+                self.log_debug(
+                    'Ignored premature client disconnection #2.')
+            else:
+                self.transport.close()
 
     def handle_error(self, status=500, message=None,
                      payload=None, exc=None, headers=None, reason=None):
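The rewritten `start()` folds both watchdogs, the slow-request timer and the keep-alive timer, into a single `Timeout` context manager around the parser reads; when it fires, `asyncio.TimeoutError` escapes the `while` loop, is logged by the new `except` branch, and `finally:` closes the transport. A runnable sketch of that control flow, with `asyncio.sleep` standing in for a read from a silent client:

```python
import asyncio
from aiohttp.helpers import Timeout

@asyncio.coroutine
def request_loop(loop):
    try:
        while True:
            with Timeout(0.1, loop=loop):    # stand-in for max(slow, keepalive)
                yield from asyncio.sleep(1)  # client never sends a request
    except asyncio.TimeoutError:
        # start() logs 'Request handler timed out.' here, then its
        # finally: block closes the transport.
        return 'timed out'

loop = asyncio.get_event_loop()
print(loop.run_until_complete(request_loop(loop)))  # -> 'timed out'
```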
@@ -366,7 +302,7 @@ def handle_error(self, status=500, message=None,
         information. It always closes current connection."""
         now = self._loop.time()
         try:
-            if self._request_handler is None:
+            if self.transport is None:
                 # client has been disconnected during writing.
                 return ()
 
aiohttp/web.py (2 additions, 28 deletions)

@@ -135,36 +135,10 @@ def connection_lost(self, handler, exc=None):
         if handler in self._connections:
             del self._connections[handler]
 
-    @asyncio.coroutine
-    def _connections_cleanup(self):
-        sleep = 0.05
-        while self._connections:
-            yield from asyncio.sleep(sleep, loop=self._loop)
-            if sleep < 5:
-                sleep = sleep * 2
-
     @asyncio.coroutine
     def finish_connections(self, timeout=None):
-        # try to close connections in 90% of graceful timeout
-        timeout90 = None
-        if timeout:
-            timeout90 = timeout / 100 * 90
-
-        for handler in self._connections.keys():
-            handler.closing(timeout=timeout90)
-
-        if timeout:
-            try:
-                yield from asyncio.wait_for(
-                    self._connections_cleanup(), timeout, loop=self._loop)
-            except asyncio.TimeoutError:
-                self._app.logger.warning(
-                    "Not all connections are closed (pending: %d)",
-                    len(self._connections))
-
-        for transport in self._connections.values():
-            transport.close()
-
+        coros = [conn.shutdown(timeout) for conn in self._connections]
+        yield from asyncio.gather(*coros, loop=self._loop)
         self._connections.clear()
 
     def __call__(self):
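`finish_connections()` loses the 90%-of-timeout heuristic and the exponential-backoff polling of `_connections_cleanup()`: each protocol now enforces its own deadline inside `shutdown()`, so the server only has to gather the results. A toy model of the new flow (the coroutine below is a hypothetical stand-in for a per-connection `shutdown()` call):

```python
import asyncio

@asyncio.coroutine
def fake_shutdown(name, delay):
    # Stand-in for ServerHttpProtocol.shutdown(timeout): each connection
    # finishes (or cancels its handler) on its own schedule.
    yield from asyncio.sleep(delay)
    return name

loop = asyncio.get_event_loop()
coros = [fake_shutdown('conn-%d' % i, 0.01 * i) for i in range(3)]
done = loop.run_until_complete(asyncio.gather(*coros, loop=loop))
print(done)  # ['conn-0', 'conn-1', 'conn-2']; all closed concurrently
```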