Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport #5572: Use new loop for web.run_app(). #5820

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/5572.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Always create a new event loop in ``aiohttp.web.run_app()``.
This adds better compatibility with ``asyncio.run()`` or if trying to run multiple apps in sequence.
46 changes: 25 additions & 21 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,11 @@ def run_app(
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
loop = asyncio.get_event_loop()
if loop is None:
loop = asyncio.new_event_loop()

# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log and access_log.name == "aiohttp.access":
Expand All @@ -488,27 +490,29 @@ def run_app(
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())

try:
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)

try:
asyncio.set_event_loop(loop)
loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
Expand Down
92 changes: 66 additions & 26 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_run_app_http(patched_loop) -> None:
cleanup_handler = make_mocked_coro()
app.on_cleanup.append(cleanup_handler)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -105,7 +105,7 @@ def test_run_app_http(patched_loop) -> None:

def test_run_app_close_loop(patched_loop) -> None:
app = web.Application()
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_run_app_mixed_bindings(
run_app_kwargs, expected_server_calls, expected_unix_server_calls, patched_loop
):
app = web.Application()
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs)
web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop)

assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls
assert patched_loop.create_server.mock_calls == expected_server_calls
Expand All @@ -435,7 +435,9 @@ def test_run_app_https(patched_loop) -> None:
app = web.Application()

ssl_context = ssl.create_default_context()
web.run_app(app, ssl_context=ssl_context, print=stopper(patched_loop))
web.run_app(
app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY,
Expand All @@ -453,7 +455,9 @@ def test_run_app_nondefault_host_port(patched_loop, aiohttp_unused_port) -> None
host = "127.0.0.1"

app = web.Application()
web.run_app(app, host=host, port=port, print=stopper(patched_loop))
web.run_app(
app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop
)

patched_loop.create_server.assert_called_with(
mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -464,7 +468,7 @@ def test_run_app_multiple_hosts(patched_loop) -> None:
hosts = ("127.0.0.1", "127.0.0.2")

app = web.Application()
web.run_app(app, host=hosts, print=stopper(patched_loop))
web.run_app(app, host=hosts, print=stopper(patched_loop), loop=patched_loop)

calls = map(
lambda h: mock.call(
Expand All @@ -483,7 +487,7 @@ def test_run_app_multiple_hosts(patched_loop) -> None:

def test_run_app_custom_backlog(patched_loop) -> None:
app = web.Application()
web.run_app(app, backlog=10, print=stopper(patched_loop))
web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None
Expand All @@ -492,7 +496,13 @@ def test_run_app_custom_backlog(patched_loop) -> None:

def test_run_app_custom_backlog_unix(patched_loop) -> None:
app = web.Application()
web.run_app(app, path="/tmp/tmpsock.sock", backlog=10, print=stopper(patched_loop))
web.run_app(
app,
path="/tmp/tmpsock.sock",
backlog=10,
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, "/tmp/tmpsock.sock", ssl=None, backlog=10
Expand All @@ -505,7 +515,7 @@ def test_run_app_http_unix_socket(patched_loop, shorttmpdir) -> None:

sock_path = str(shorttmpdir / "socket.sock")
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, print=printer)
web.run_app(app, path=sock_path, print=printer, loop=patched_loop)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=None, backlog=128
Expand All @@ -520,7 +530,9 @@ def test_run_app_https_unix_socket(patched_loop, shorttmpdir) -> None:
sock_path = str(shorttmpdir / "socket.sock")
ssl_context = ssl.create_default_context()
printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, path=sock_path, ssl_context=ssl_context, print=printer)
web.run_app(
app, path=sock_path, ssl_context=ssl_context, print=printer, loop=patched_loop
)

patched_loop.create_unix_server.assert_called_with(
mock.ANY, sock_path, ssl=ssl_context, backlog=128
Expand All @@ -534,7 +546,10 @@ def test_run_app_abstract_linux_socket(patched_loop) -> None:
sock_path = b"\x00" + uuid4().hex.encode("ascii")
app = web.Application()
web.run_app(
app, path=sock_path.decode("ascii", "ignore"), print=stopper(patched_loop)
app,
path=sock_path.decode("ascii", "ignore"),
print=stopper(patched_loop),
loop=patched_loop,
)

patched_loop.create_unix_server.assert_called_with(
Expand All @@ -551,7 +566,7 @@ def test_run_app_preexisting_inet_socket(patched_loop, mocker) -> None:
_, port = sock.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -569,7 +584,7 @@ def test_run_app_preexisting_inet6_socket(patched_loop) -> None:
port = sock.getsockname()[1]

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -588,7 +603,7 @@ def test_run_app_preexisting_unix_socket(patched_loop, mocker) -> None:
os.unlink(sock_path)

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=sock, print=printer)
web.run_app(app, sock=sock, print=printer, loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, sock=sock, backlog=128, ssl=None
Expand All @@ -608,7 +623,7 @@ def test_run_app_multiple_preexisting_sockets(patched_loop) -> None:
_, port2 = sock2.getsockname()

printer = mock.Mock(wraps=stopper(patched_loop))
web.run_app(app, sock=(sock1, sock2), print=printer)
web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop)

patched_loop.create_server.assert_has_calls(
[
Expand Down Expand Up @@ -664,7 +679,7 @@ def test_startup_cleanup_signals_even_on_failure(patched_loop) -> None:
app.on_cleanup.append(cleanup_handler)

with pytest.raises(RuntimeError):
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)

startup_handler.assert_called_once_with(app)
cleanup_handler.assert_called_once_with(app)
Expand All @@ -682,7 +697,7 @@ async def make_app():
app.on_cleanup.append(cleanup_handler)
return app

web.run_app(make_app(), print=stopper(patched_loop))
web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop)

patched_loop.create_server.assert_called_with(
mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
Expand All @@ -703,7 +718,12 @@ def test_run_app_default_logger(monkeypatch, patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_any_call(logging.DEBUG)
mock_logger.hasHandlers.assert_called_with()
assert isinstance(mock_logger.addHandler.call_args[0][0], logging.StreamHandler)
Expand All @@ -721,7 +741,12 @@ def test_run_app_default_logger_setup_requires_debug(patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -739,7 +764,12 @@ def test_run_app_default_logger_setup_requires_default_logger(patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_not_called()
mock_logger.addHandler.assert_not_called()
Expand All @@ -757,7 +787,12 @@ def test_run_app_default_logger_setup_only_if_unconfigured(patched_loop):
mock_logger.configure_mock(**attrs)

app = web.Application()
web.run_app(app, print=stopper(patched_loop), access_log=mock_logger)
web.run_app(
app,
print=stopper(patched_loop),
access_log=mock_logger,
loop=patched_loop,
)
mock_logger.setLevel.assert_not_called()
mock_logger.hasHandlers.assert_called_with()
mock_logger.addHandler.assert_not_called()
Expand All @@ -774,7 +809,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.cancelled()


Expand All @@ -792,7 +827,7 @@ async def on_startup(app):

app.on_startup.append(on_startup)

web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()


Expand All @@ -818,7 +853,7 @@ async def on_startup(app):

exc_handler = mock.Mock()
patched_loop.set_exception_handler(exc_handler)
web.run_app(app, print=stopper(patched_loop))
web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
assert task.done()

msg = {
Expand All @@ -839,7 +874,12 @@ def base_runner_init_spy(self, *args, **kwargs):

app = web.Application()
monkeypatch.setattr(BaseRunner, "__init__", base_runner_init_spy)
web.run_app(app, keepalive_timeout=new_timeout, print=stopper(patched_loop))
web.run_app(
app,
keepalive_timeout=new_timeout,
print=stopper(patched_loop),
loop=patched_loop,
)


@pytest.mark.skipif(not PY_37, reason="contextvars support is required")
Expand Down Expand Up @@ -871,5 +911,5 @@ async def init():
count += 1
return app

web.run_app(init(), print=stopper(patched_loop))
web.run_app(init(), print=stopper(patched_loop), loop=patched_loop)
assert count == 3
26 changes: 26 additions & 0 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import platform
import signal
import sys
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -162,3 +163,28 @@ async def mock_create_server(*args, **kwargs):
assert server is runner.server
assert host is None
assert port == 8080


@pytest.mark.skipif(sys.version_info < (3, 7), reason="Requires asyncio.run()")
def test_run_after_asyncio_run() -> None:
async def nothing():
pass

def spy():
spy.called = True

spy.called = False

async def shutdown():
spy()
raise web.GracefulExit()

# asyncio.run() creates a new loop and closes it.
asyncio.run(nothing())

app = web.Application()
# create_task() will delay the function until app is run.
app.on_startup.append(lambda a: asyncio.create_task(shutdown()))

web.run_app(app)
assert spy.called, "run_app() should work after asyncio.run()."