From a3d9576f46e51ab483c5dfa9dfb4c6d7c4ca871a Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 21 Jun 2021 16:32:04 +0100 Subject: [PATCH 1/3] Use new loop for web.run_app(). --- CHANGES/5572.feature | 2 + aiohttp/web.py | 46 +++++++++++--------- tests/test_run_app.py | 92 ++++++++++++++++++++++++++++------------ tests/test_web_runner.py | 24 +++++++++++ 4 files changed, 117 insertions(+), 47 deletions(-) create mode 100644 CHANGES/5572.feature diff --git a/CHANGES/5572.feature b/CHANGES/5572.feature new file mode 100644 index 00000000000..a5d60fb6ee3 --- /dev/null +++ b/CHANGES/5572.feature @@ -0,0 +1,2 @@ +Always create a new event loop in ``aiohttp.web.run_app()``. +This adds better compatibility with ``asyncio.run()`` or if trying to run multiple apps in sequence. diff --git a/aiohttp/web.py b/aiohttp/web.py index 5c7518f00ee..b20957e485c 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -477,9 +477,11 @@ def run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" - loop = asyncio.get_event_loop() + if loop is None: + loop = asyncio.new_event_loop() # Configure if and only if in debugging mode and using the default logger if loop.get_debug() and access_log and access_log.name == "aiohttp.access": @@ -488,27 +490,29 @@ def run_app( if not access_log.hasHandlers(): access_log.addHandler(logging.StreamHandler()) - try: - main_task = loop.create_task( - _run_app( - app, - host=host, - port=port, - path=path, - sock=sock, - shutdown_timeout=shutdown_timeout, - keepalive_timeout=keepalive_timeout, - ssl_context=ssl_context, - print=print, - backlog=backlog, - access_log_class=access_log_class, - access_log_format=access_log_format, - access_log=access_log, - handle_signals=handle_signals, - reuse_address=reuse_address, - reuse_port=reuse_port, - ) + main_task = loop.create_task( + _run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + keepalive_timeout=keepalive_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, ) + ) + + try: + asyncio.set_event_loop(loop) loop.run_until_complete(main_task) except (GracefulExit, KeyboardInterrupt): # pragma: no cover pass diff --git a/tests/test_run_app.py b/tests/test_run_app.py index 74e951cd11a..e03a5fd6c90 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -94,7 +94,7 @@ def test_run_app_http(patched_loop) -> None: cleanup_handler = make_mocked_coro() app.on_cleanup.append(cleanup_handler) - web.run_app(app, print=stopper(patched_loop)) + web.run_app(app, print=stopper(patched_loop), loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None @@ -105,7 +105,7 @@ def test_run_app_http(patched_loop) -> None: def test_run_app_close_loop(patched_loop) -> None: app = web.Application() - web.run_app(app, print=stopper(patched_loop)) + web.run_app(app, print=stopper(patched_loop), loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None @@ -425,7 +425,7 @@ def test_run_app_mixed_bindings( run_app_kwargs, expected_server_calls, expected_unix_server_calls, patched_loop ): app = web.Application() - web.run_app(app, print=stopper(patched_loop), **run_app_kwargs) + web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop) assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls assert patched_loop.create_server.mock_calls == expected_server_calls @@ -435,7 +435,9 @@ def test_run_app_https(patched_loop) -> None: app = web.Application() ssl_context = ssl.create_default_context() - web.run_app(app, ssl_context=ssl_context, print=stopper(patched_loop)) + web.run_app( + app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop + ) patched_loop.create_server.assert_called_with( mock.ANY, @@ -453,7 +455,9 @@ def test_run_app_nondefault_host_port(patched_loop, aiohttp_unused_port) -> None host = "127.0.0.1" app = web.Application() - web.run_app(app, host=host, port=port, print=stopper(patched_loop)) + web.run_app( + app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop + ) patched_loop.create_server.assert_called_with( mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None @@ -464,7 +468,7 @@ def test_run_app_multiple_hosts(patched_loop) -> None: hosts = ("127.0.0.1", "127.0.0.2") app = web.Application() - web.run_app(app, host=hosts, print=stopper(patched_loop)) + web.run_app(app, host=hosts, print=stopper(patched_loop), loop=patched_loop) calls = map( lambda h: mock.call( @@ -483,7 +487,7 @@ def test_run_app_multiple_hosts(patched_loop) -> None: def test_run_app_custom_backlog(patched_loop) -> None: app = web.Application() - web.run_app(app, backlog=10, print=stopper(patched_loop)) + web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None @@ -492,7 +496,13 @@ def test_run_app_custom_backlog(patched_loop) -> None: def test_run_app_custom_backlog_unix(patched_loop) -> None: app = web.Application() - web.run_app(app, path="/tmp/tmpsock.sock", backlog=10, print=stopper(patched_loop)) + web.run_app( + app, + path="/tmp/tmpsock.sock", + backlog=10, + print=stopper(patched_loop), + loop=patched_loop, + ) patched_loop.create_unix_server.assert_called_with( mock.ANY, "/tmp/tmpsock.sock", ssl=None, backlog=10 @@ -505,7 +515,7 @@ def test_run_app_http_unix_socket(patched_loop, shorttmpdir) -> None: sock_path = str(shorttmpdir / "socket.sock") printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, path=sock_path, print=printer) + web.run_app(app, path=sock_path, print=printer, loop=patched_loop) patched_loop.create_unix_server.assert_called_with( mock.ANY, sock_path, ssl=None, backlog=128 @@ -520,7 +530,9 @@ def test_run_app_https_unix_socket(patched_loop, shorttmpdir) -> None: sock_path = str(shorttmpdir / "socket.sock") ssl_context = ssl.create_default_context() printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer) + web.run_app( + app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop + ) patched_loop.create_unix_server.assert_called_with( mock.ANY, sock_path, ssl=ssl_context, backlog=128 @@ -534,7 +546,10 @@ def test_run_app_abstract_linux_socket(patched_loop) -> None: sock_path = b"\x00" + uuid4().hex.encode("ascii") app = web.Application() web.run_app( - app, path=sock_path.decode("ascii", "ignore"), print=stopper(patched_loop) + app, + path=sock_path.decode("ascii", "ignore"), + print=stopper(patched_loop), + loop=patched_loop, ) patched_loop.create_unix_server.assert_called_with( @@ -551,7 +566,7 @@ def test_run_app_preexisting_inet_socket(patched_loop, mocker) -> None: _, port = sock.getsockname() printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, sock=sock, print=printer) + web.run_app(app, sock=sock, print=printer, loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, sock=sock, backlog=128, ssl=None @@ -569,7 +584,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop) -> None: port = sock.getsockname()[1] printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, sock=sock, print=printer) + web.run_app(app, sock=sock, print=printer, loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, sock=sock, backlog=128, ssl=None @@ -588,7 +603,7 @@ def test_run_app_preexisting_unix_socket(patched_loop, mocker) -> None: os.unlink(sock_path) printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, sock=sock, print=printer) + web.run_app(app, sock=sock, print=printer, loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, sock=sock, backlog=128, ssl=None @@ -608,7 +623,7 @@ def test_run_app_multiple_preexisting_sockets(patched_loop) -> None: _, port2 = sock2.getsockname() printer = mock.Mock(wraps=stopper(patched_loop)) - web.run_app(app, sock=(sock1, sock2), print=printer) + web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop) patched_loop.create_server.assert_has_calls( [ @@ -664,7 +679,7 @@ def test_startup_cleanup_signals_even_on_failure(patched_loop) -> None: app.on_cleanup.append(cleanup_handler) with pytest.raises(RuntimeError): - web.run_app(app, print=stopper(patched_loop)) + web.run_app(app, print=stopper(patched_loop), loop=patched_loop) startup_handler.assert_called_once_with(app) cleanup_handler.assert_called_once_with(app) @@ -682,7 +697,7 @@ async def make_app(): app.on_cleanup.append(cleanup_handler) return app - web.run_app(make_app(), print=stopper(patched_loop)) + web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop) patched_loop.create_server.assert_called_with( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None @@ -703,7 +718,12 @@ def test_run_app_default_logger(monkeypatch, patched_loop): mock_logger.configure_mock(**attrs) app = web.Application() - web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + web.run_app( + app, + print=stopper(patched_loop), + access_log=mock_logger, + loop=patched_loop, + ) mock_logger.setLevel.assert_any_call(logging.DEBUG) mock_logger.hasHandlers.assert_called_with() assert isinstance(mock_logger.addHandler.call_args[0][0], logging.StreamHandler) @@ -721,7 +741,12 @@ def test_run_app_default_logger_setup_requires_debug(patched_loop): mock_logger.configure_mock(**attrs) app = web.Application() - web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + web.run_app( + app, + print=stopper(patched_loop), + access_log=mock_logger, + loop=patched_loop, + ) mock_logger.setLevel.assert_not_called() mock_logger.hasHandlers.assert_not_called() mock_logger.addHandler.assert_not_called() @@ -739,7 +764,12 @@ def test_run_app_default_logger_setup_requires_default_logger(patched_loop): mock_logger.configure_mock(**attrs) app = web.Application() - web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + web.run_app( + app, + print=stopper(patched_loop), + access_log=mock_logger, + loop=patched_loop, + ) mock_logger.setLevel.assert_not_called() mock_logger.hasHandlers.assert_not_called() mock_logger.addHandler.assert_not_called() @@ -757,7 +787,12 @@ def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop): mock_logger.configure_mock(**attrs) app = web.Application() - web.run_app(app, print=stopper(patched_loop), access_log=mock_logger) + web.run_app( + app, + print=stopper(patched_loop), + access_log=mock_logger, + loop=patched_loop, + ) mock_logger.setLevel.assert_not_called() mock_logger.hasHandlers.assert_called_with() mock_logger.addHandler.assert_not_called() @@ -774,7 +809,7 @@ async def on_startup(app): app.on_startup.append(on_startup) - web.run_app(app, print=stopper(patched_loop)) + web.run_app(app, print=stopper(patched_loop), loop=patched_loop) assert task.cancelled() @@ -792,7 +827,7 @@ async def on_startup(app): app.on_startup.append(on_startup) - web.run_app(app, print=stopper(patched_loop)) + web.run_app(app, print=stopper(patched_loop), loop=patched_loop) assert task.done() @@ -818,7 +853,7 @@ async def on_startup(app): exc_handler = mock.Mock() patched_loop.set_exception_handler(exc_handler) - web.run_app(app, print=stopper(patched_loop)) + web.run_app(app, print=stopper(patched_loop), loop=patched_loop) assert task.done() msg = { @@ -839,7 +874,12 @@ def base_runner_init_spy(self, *args, **kwargs): app = web.Application() monkeypatch.setattr(BaseRunner, "__init__", base_runner_init_spy) - web.run_app(app, keepalive_timeout=new_timeout, print=stopper(patched_loop)) + web.run_app( + app, + keepalive_timeout=new_timeout, + print=stopper(patched_loop), + loop=patched_loop, + ) @pytest.mark.skipif(not PY_37, reason="contextvars support is required") @@ -871,5 +911,5 @@ async def init(): count += 1 return app - web.run_app(init(), print=stopper(patched_loop)) + web.run_app(init(), print=stopper(patched_loop), loop=patched_loop) assert count == 3 diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index af6df1aa8e0..5c45b234c55 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -162,3 +162,27 @@ async def mock_create_server(*args, **kwargs): assert server is runner.server assert host is None assert port == 8080 + + +def test_run_after_asyncio_run() -> None: + async def nothing(): + pass + + def spy(): + spy.called = True + + spy.called = False + + async def shutdown(): + spy() + raise web.GracefulExit() + + # asyncio.run() creates a new loop and closes it. + asyncio.run(nothing()) + + app = web.Application() + # create_task() will delay the function until app is run. + app.on_startup.append(lambda a: asyncio.create_task(shutdown())) + + web.run_app(app) + assert spy.called, "run_app() should work after asyncio.run()." From f35596b512c520a395154d2adcf557a39e167b1e Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 21 Jun 2021 19:45:40 +0100 Subject: [PATCH 2/3] Skip test on 3.6 --- tests/test_web_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 5c45b234c55..9194eb97670 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -164,6 +164,7 @@ async def mock_create_server(*args, **kwargs): assert port == 8080 +@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires asyncio.run()") def test_run_after_asyncio_run() -> None: async def nothing(): pass From 38fd9caed668a6a47d1c3fce17b781189fbf2010 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Mon, 21 Jun 2021 19:53:00 +0100 Subject: [PATCH 3/3] Update test_web_runner.py --- tests/test_web_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 9194eb97670..8c08a5f5fbd 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -1,6 +1,7 @@ import asyncio import platform import signal +import sys from unittest.mock import patch import pytest