From 66c7033d81df63afb7dde6c7fabd3d8d626d68e5 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 30 Nov 2014 22:31:37 -0800 Subject: [PATCH] gunicorn worker for aiohttp.web --- CHANGES.txt | 13 ++++ aiohttp/log.py | 1 + aiohttp/server.py | 18 ++++-- aiohttp/web.py | 56 +++++++++++++++-- aiohttp/worker.py | 119 +++++++++-------------------------- examples/web_srv.py | 2 +- tests/test_web.py | 38 +++++++++++ tests/test_worker.py | 147 +++++++++---------------------------------- 8 files changed, 179 insertions(+), 215 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index fbfcd1758b4..24f76600ae0 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,6 +1,14 @@ CHANGES ======= +Unreleased +------------------- + +- Gunicorn worker for aiohttp.web + +- Removed deprecated AsyncGunicornWorker + + 0.11.0 (11-29-2014) ------------------- @@ -8,11 +16,13 @@ CHANGES - Make websocket subprotocols conform to spec #181 + 0.10.2 (11-19-2014) ------------------- - Don't unquote `environ['PATH_INFO']` in wsgi.py #177 + 0.10.1 (11-17-2014) ------------------- @@ -21,6 +31,7 @@ CHANGES - Fix multidict `__iter__`, the method should iterate over keys, not (key, value) pairs. + 0.10.0 (11-13-2014) ------------------- @@ -38,11 +49,13 @@ CHANGES - Set server.transport to None on .closing() #172 + 0.9.3 (10-30-2014) ------------------ - Fix compatibility with asyncio 3.4.1+ #170 + 0.9.2 (10-16-2014) ------------------ diff --git a/aiohttp/log.py b/aiohttp/log.py index 47ddbacb51c..4f8fc984751 100644 --- a/aiohttp/log.py +++ b/aiohttp/log.py @@ -5,4 +5,5 @@ client_log = logging.getLogger('aiohttp.client') internal_log = logging.getLogger('aiohttp.internal') server_log = logging.getLogger('aiohttp.server') +web_log = logging.getLogger('aiohttp.web') websocket_log = logging.getLogger('aiohttp.websocket') diff --git a/aiohttp/server.py b/aiohttp/server.py index 6560bdfa44f..b0893ff29d6 100644 --- a/aiohttp/server.py +++ b/aiohttp/server.py @@ -72,10 +72,18 @@ class ServerHttpProtocol(aiohttp.StreamProtocol): _request_parser = aiohttp.HttpRequestParser() # default request parser - def __init__(self, *, loop=None, keep_alive=None, - timeout=15, tcp_keepalive=True, allowed_methods=(), - debug=False, log=server_log, access_log=access_log, - access_log_format=ACCESS_LOG_FORMAT, **kwargs): + def __init__(self, *, loop=None, + keep_alive=None, + timeout=15, + tcp_keepalive=True, + allowed_methods=(), + log=server_log, + access_log=access_log, + access_log_format=ACCESS_LOG_FORMAT, + host="", + port=0, + debug=False, + **kwargs): super().__init__(loop=loop, **kwargs) self._keep_alive_period = keep_alive # number of seconds to keep alive @@ -84,6 +92,8 @@ def __init__(self, *, loop=None, keep_alive=None, self._request_prefix = aiohttp.HttpPrefixParser(allowed_methods) self._loop = loop if loop is not None else asyncio.get_event_loop() + self.host = host + self.port = port self.log = log self.debug = debug self.access_log = access_log diff --git a/aiohttp/web.py b/aiohttp/web.py index 91f62d55848..eee6510b4a3 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -13,6 +13,7 @@ from urllib.parse import urlsplit, parse_qsl, unquote, urlencode from .abc import AbstractRouter, AbstractMatchInfo +from .log import web_log from .multidict import (CaseInsensitiveMultiDict, CaseInsensitiveMutableMultiDict, MultiDict, @@ -1100,6 +1101,16 @@ def __init__(self, app, **kwargs): super().__init__(**kwargs) self._app = app + def connection_made(self, transport): + super().connection_made(transport) + + self._app.connection_made(self, transport) + + def connection_lost(self, exc): + self._app.connection_lost(self, exc) + + super().connection_lost(exc) + @asyncio.coroutine def handle_request(self, message, payload): request = Request(self._app, message, payload, @@ -1136,13 +1147,13 @@ def __init__(self, *, loop=None, router=None, **kwargs): # TODO: explicitly accept *debug* param if loop is None: loop = asyncio.get_event_loop() - self._kwargs = kwargs if router is None: router = UrlDispatcher() assert isinstance(router, AbstractRouter), router self._router = router self._loop = loop self._finish_callbacks = [] + self._connections = {} @property def router(self): @@ -1152,12 +1163,26 @@ def router(self): def loop(self): return self._loop - def make_handler(self): - return RequestHandler(self, loop=self._loop, **self._kwargs) + def make_handler(self, **kwargs): + return RequestHandler(self, loop=self._loop, **kwargs) + + @property + def connections(self): + return list(self._connections.keys()) + + def connection_made(self, handler, transport): + self._connections[handler] = transport + + def connection_lost(self, handler, exc=None): + if handler in self._connections: + del self._connections[handler] @asyncio.coroutine def finish(self): - for (cb, args, kwargs) in self._finish_callbacks: + callbacks = self._finish_callbacks + self._finish_callbacks = [] + + for (cb, args, kwargs) in callbacks: try: res = cb(*args, **kwargs) if (asyncio.iscoroutine(res) or @@ -1170,5 +1195,28 @@ def finish(self): 'application': self, }) + @asyncio.coroutine + def finish_connections(self, timeout=None): + for handler in self._connections.keys(): + handler.closing() + + def cleanup(): + while self._connections: + yield from asyncio.sleep(0.5, loop=self._loop) + + if timeout: + try: + yield from asyncio.wait_for( + cleanup(), timeout, loop=self._loop) + except asyncio.TimeoutError: + web_log.warn( + "Not all connections closed (pending: %d)", + len(self._connections)) + + for transport in self._connections.values(): + transport.close() + + self._connections.clear() + def register_on_finish(self, func, *args, **kwargs): self._finish_callbacks.insert(0, (func, args, kwargs)) diff --git a/aiohttp/worker.py b/aiohttp/worker.py index baa1995f050..253e8b6705a 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -1,24 +1,18 @@ -"""Async gunicorn worker.""" -__all__ = ['AsyncGunicornWorker', 'PortMapperWorker'] +"""Async gunicorn worker for auihttp.wen.Application.""" +__all__ = ['GunicornWebWorker'] import asyncio import functools import os import gunicorn.workers.base as base -import warnings -from aiohttp.wsgi import WSGIServerHttpProtocol - -class AsyncGunicornWorker(base.Worker): +class GunicornWebWorker(base.Worker): def __init__(self, *args, **kw): # pragma: no cover - warnings.warn("AsyncGunicornWorker is deprecated " - "starting from 0.11 release, " - "use standard gaiohttp worker.", DeprecationWarning) super().__init__(*args, **kw) + self.servers = [] - self.connections = {} def init_process(self): # create new event_loop after fork @@ -37,33 +31,36 @@ def run(self): finally: self.loop.close() - def wrap_protocol(self, proto): - proto.connection_made = _wrp( - proto, proto.connection_made, self.connections) - proto.connection_lost = _wrp( - proto, proto.connection_lost, self.connections, False) - return proto - - def factory(self, wsgi, host, port): - proto = WSGIServerHttpProtocol( - wsgi, loop=self.loop, + def factory(self, app, host, port): + return app.make_handler( + host=host, + port=port, log=self.log, debug=self.cfg.debug, keep_alive=self.cfg.keepalive, access_log=self.log.access_log, access_log_format=self.cfg.access_log_format) - return self.wrap_protocol(proto) def get_factory(self, sock, host, port): return functools.partial(self.factory, self.wsgi, host, port) @asyncio.coroutine def close(self): - try: - if hasattr(self.wsgi, 'close'): - yield from self.wsgi.close() - except: - self.log.exception('Process shutdown exception') + if self.servers: + self.log.info("Stopping server: %s, connections: %s", + self.pid, len(self.wsgi.connections)) + + # stop accepting connections + for server in self.servers: + server.close() + self.servers.clear() + + # stop alive connections + yield from self.wsgi.finish_connections( + timeout=self.cfg.graceful_timeout / 100 * 80) + + # stop application + yield from self.wsgi.finish() @asyncio.coroutine def _run(self): @@ -75,73 +72,15 @@ def _run(self): # If our parent changed then we shut down. pid = os.getpid() try: - while self.alive or self.connections: + while self.alive: self.notify() - if (self.alive and - pid == os.getpid() and self.ppid != os.getppid()): - self.log.info("Parent changed, shutting down: %s", self) + if pid == os.getpid() and self.ppid != os.getppid(): self.alive = False - - # stop accepting requests - if not self.alive: - if self.servers: - self.log.info( - "Stopping server: %s, connections: %s", - pid, len(self.connections)) - for server in self.servers: - server.close() - self.servers.clear() - - # prepare connections for closing - for conn in self.connections.values(): - if hasattr(conn, 'closing'): - conn.closing() - - yield from asyncio.sleep(1.0, loop=self.loop) - except KeyboardInterrupt: + self.log.info("Parent changed, shutting down: %s", self) + else: + yield from asyncio.sleep(1.0, loop=self.loop) + except (Exception, BaseException, GeneratorExit, KeyboardInterrupt): pass - if self.servers: - for server in self.servers: - server.close() - yield from self.close() - - -class PortMapperWorker(AsyncGunicornWorker): - """Special worker that uses different wsgi application depends on port. - - Main wsgi application object has to be dictionary: - """ - - def get_factory(self, sock, host, port): - return functools.partial(self.factory, self.wsgi[port], host, port) - - @asyncio.coroutine - def close(self): - for port, wsgi in self.wsgi.items(): - try: - if hasattr(wsgi, 'close'): - yield from wsgi.close() - except: - self.log.exception('Process shutdown exception') - - -class _wrp: - - def __init__(self, proto, meth, tracking, add=True): - self._proto = proto - self._id = id(proto) - self._meth = meth - self._tracking = tracking - self._add = add - - def __call__(self, *args): - if self._add: - self._tracking[self._id] = self._proto - elif self._id in self._tracking: - del self._tracking[self._id] - - conn = self._meth(*args) - return conn diff --git a/examples/web_srv.py b/examples/web_srv.py index 9bea6f12b76..5a0d0709a1f 100755 --- a/examples/web_srv.py +++ b/examples/web_srv.py @@ -32,7 +32,7 @@ def change_body(request): @asyncio.coroutine def hello(request): resp = StreamResponse(request) - name = request.match_info.get('name', 'Anonimous') + name = request.match_info.get('name', 'Anonymous') answer = ('Hello, ' + name).encode('utf8') resp.content_length = len(answer) resp.send_headers() diff --git a/tests/test_web.py b/tests/test_web.py index 7523b1010cf..38d0376f2bb 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -82,3 +82,41 @@ def test_non_default_router(self): router = web.UrlDispatcher() app = web.Application(loop=self.loop, router=router) self.assertIs(router, app.router) + + def test_connections(self): + app = web.Application(loop=self.loop) + self.assertEqual(app.connections, []) + + handler = object() + transport = object() + app.connection_made(handler, transport) + self.assertEqual(app.connections, [handler]) + + app.connection_lost(handler, None) + self.assertEqual(app.connections, []) + + def test_finish_connection_no_timeout(self): + app = web.Application(loop=self.loop) + handler = mock.Mock() + transport = mock.Mock() + app.connection_made(handler, transport) + + self.loop.run_until_complete(app.finish_connections()) + + app.connection_lost(handler, None) + self.assertEqual(app.connections, []) + handler.closing.assert_called_with() + transport.close.assert_called_with() + + def test_finish_connection_timeout(self): + app = web.Application(loop=self.loop) + handler = mock.Mock() + transport = mock.Mock() + app.connection_made(handler, transport) + + self.loop.run_until_complete(app.finish_connections(timeout=0.1)) + + app.connection_lost(handler, None) + self.assertEqual(app.connections, []) + handler.closing.assert_called_with() + transport.close.assert_called_with() diff --git a/tests/test_worker.py b/tests/test_worker.py index 7c46e78e49f..fd6297a486d 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -8,13 +8,13 @@ except ImportError as error: raise unittest.SkipTest('gunicorn required') from error -from aiohttp.wsgi import WSGIServerHttpProtocol - -class TestWorker(worker.AsyncGunicornWorker): +class TestWorker(worker.GunicornWebWorker): def __init__(self): - self.connections = {} + self.servers = [] + self.cfg = unittest.mock.Mock() + self.cfg.graceful_timeout = 100 class WorkerTests(unittest.TestCase): @@ -31,7 +31,7 @@ def tearDown(self): def test_init_process(self, m_asyncio): try: self.worker.init_process() - except AttributeError: + except TypeError: pass self.assertTrue(m_asyncio.get_event_loop.return_value.close.called) @@ -55,7 +55,7 @@ def test_factory(self): f = self.worker.factory( self.worker.wsgi, 'localhost', 8080) - self.assertIsInstance(f, WSGIServerHttpProtocol) + self.assertIs(f, self.worker.wsgi.make_handler.return_value) @unittest.mock.patch('aiohttp.worker.asyncio') def test__run(self, m_asyncio): @@ -66,6 +66,9 @@ def test__run(self, m_asyncio): sock.cfg_addr = ('localhost', 8080) self.worker.sockets = [sock] self.worker.wsgi = unittest.mock.Mock() + self.worker.close = unittest.mock.Mock() + self.worker.close.return_value = asyncio.Future(loop=self.loop) + self.worker.close.return_value.set_result(()) self.worker.log = unittest.mock.Mock() self.worker.notify = unittest.mock.Mock() loop = self.worker.loop = unittest.mock.Mock() @@ -77,31 +80,6 @@ def test__run(self, m_asyncio): self.assertTrue(self.worker.log.info.called) self.assertTrue(self.worker.notify.called) - def test__run_connections(self): - conn = unittest.mock.Mock() - self.worker.ppid = 1 - self.worker.alive = False - self.worker.servers = [unittest.mock.Mock()] - self.worker.connections = {1: conn} - self.worker.sockets = [] - self.worker.wsgi = unittest.mock.Mock() - self.worker.log = unittest.mock.Mock() - self.worker.loop = self.loop - self.worker.loop.create_server = unittest.mock.Mock() - self.worker.notify = unittest.mock.Mock() - - def _close_conns(): - yield from asyncio.sleep(0.1, loop=self.loop) - self.worker.connections = {} - - asyncio.async(_close_conns(), loop=self.loop) - self.loop.run_until_complete(self.worker._run()) - - self.assertTrue(self.worker.log.info.called) - self.assertTrue(self.worker.notify.called) - self.assertFalse(self.worker.servers) - self.assertTrue(conn.closing.called) - @unittest.mock.patch('aiohttp.worker.os') @unittest.mock.patch('aiohttp.worker.asyncio.sleep') def test__run_exc(self, m_sleep, m_os): @@ -120,92 +98,29 @@ def test__run_exc(self, m_sleep, m_os): slp.set_exception(KeyboardInterrupt) m_sleep.return_value = slp + self.worker.close = unittest.mock.Mock() + self.worker.close.return_value = asyncio.Future(loop=self.loop) + self.worker.close.return_value.set_result(1) + self.loop.run_until_complete(self.worker._run()) self.assertTrue(m_sleep.called) - self.assertTrue(self.worker.servers[0].close.called) + self.assertTrue(self.worker.close.called) - def test_close_wsgi_app(self): - self.worker.ppid = 1 - self.worker.alive = False - self.worker.servers = [unittest.mock.Mock()] - self.worker.connections = {} - self.worker.sockets = [] + def test_close(self): + srv = unittest.mock.Mock() + self.worker.servers = [srv] self.worker.log = unittest.mock.Mock() - self.worker.loop = self.loop - self.worker.loop.create_server = unittest.mock.Mock() - self.worker.notify = unittest.mock.Mock() - - self.worker.wsgi = unittest.mock.Mock() - self.worker.wsgi.close.return_value = asyncio.Future(loop=self.loop) - self.worker.wsgi.close.return_value.set_result(1) - - self.loop.run_until_complete(self.worker._run()) - self.assertTrue(self.worker.wsgi.close.called) - - self.worker.wsgi = unittest.mock.Mock() - self.worker.wsgi.close.return_value = asyncio.Future(loop=self.loop) - self.worker.wsgi.close.return_value.set_exception(ValueError()) - - self.loop.run_until_complete(self.worker._run()) - self.assertTrue(self.worker.wsgi.close.called) - - def test_portmapper_worker(self): - wsgi = {1: object(), 2: object()} - - class Worker(worker.PortMapperWorker): - - def __init__(self, wsgi): - self.wsgi = wsgi - - def factory(self, wsgi, host, port): - return wsgi - - w = Worker(wsgi) - self.assertIs( - wsgi[1], w.get_factory(object(), '', 1)()) - self.assertIs( - wsgi[2], w.get_factory(object(), '', 2)()) - - def test_portmapper_close_wsgi_app(self): - - class Worker(worker.PortMapperWorker): - def __init__(self, wsgi): - self.wsgi = wsgi - - wsgi = {1: unittest.mock.Mock(), 2: unittest.mock.Mock()} - wsgi[1].close.return_value = asyncio.Future(loop=self.loop) - wsgi[1].close.return_value.set_result(1) - wsgi[2].close.return_value = asyncio.Future(loop=self.loop) - wsgi[2].close.return_value.set_exception(ValueError()) - - w = Worker(wsgi) - w.ppid = 1 - w.alive = False - w.servers = [unittest.mock.Mock()] - w.connections = {} - w.sockets = [] - w.log = unittest.mock.Mock() - w.loop = self.loop - w.loop.create_server = unittest.mock.Mock() - w.notify = unittest.mock.Mock() - - self.loop.run_until_complete(w._run()) - self.assertTrue(wsgi[1].close.called) - self.assertTrue(wsgi[2].close.called) - - def test_wrp(self): - conn = object() - tracking = {} - meth = unittest.mock.Mock() - wrp = worker._wrp(conn, meth, tracking) - wrp() - - self.assertIn(id(conn), tracking) - self.assertTrue(meth.called) - - meth = unittest.mock.Mock() - wrp = worker._wrp(conn, meth, tracking, False) - wrp() - - self.assertNotIn(1, tracking) - self.assertTrue(meth.called) + app = self.worker.wsgi = unittest.mock.Mock() + app.connections = [object()] + app.finish.return_value = asyncio.Future(loop=self.loop) + app.finish.return_value.set_result(1) + app.finish_connections.return_value = asyncio.Future(loop=self.loop) + app.finish_connections.return_value.set_result(1) + + self.loop.run_until_complete(self.worker.close()) + app.finish.assert_called_with() + app.finish_connections.assert_called_with(timeout=80.0) + srv.close.assert_called_with() + self.assertEqual(self.worker.servers, []) + + self.loop.run_until_complete(self.worker.close())