Skip to content

Commit

Permalink
Add a graceful shutdown period to allow tasks to complete. (#7188) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Feb 11, 2023
1 parent 5998102 commit dc0a8d4
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGES/7188.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a graceful shutdown period which allows pending tasks to complete before the application's cleanup is called. The period can be adjusted with the ``shutdown_timeout`` parameter. -- by :user:`Dreamsorcerer`.
2 changes: 1 addition & 1 deletion aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def start_server(
return
self._loop = loop
self._ssl = kwargs.pop("ssl", None)
self.runner = await self._make_runner(**kwargs)
self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
await self.runner.setup()
if not self.port:
self.port = 0
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,4 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
try:
return await self._sendfile(request, fobj, offset, count)
finally:
await loop.run_in_executor(None, fobj.close)
await asyncio.shield(loop.run_in_executor(None, fobj.close))
23 changes: 22 additions & 1 deletion aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import signal
import socket
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, List, Optional, Set

from yarl import URL
Expand Down Expand Up @@ -74,11 +75,26 @@ async def stop(self) -> None:
# named pipes do not have wait_closed property
if hasattr(self._server, "wait_closed"):
await self._server.wait_closed()

# Wait for pending tasks for a given time limit.
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(
self._wait(asyncio.current_task()), timeout=self._shutdown_timeout
)

await self._runner.shutdown()
assert self._runner.server
await self._runner.server.shutdown(self._shutdown_timeout)
self._runner._unreg_site(self)

async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None:
    """Wait until no tasks remain other than the excluded ones.

    Excluded from the wait are the tasks recorded by the runner in
    ``starting_tasks``, the task running this coroutine, and *parent_task*.
    Tasks that appear while waiting are picked up on the next pass of the
    loop, so this only returns once the set of remaining tasks is empty.

    :param parent_task: task to ignore in addition to the current one
                        (may be ``None``).
    """
    ignored = self._runner.starting_tasks | {asyncio.current_task(), parent_task}
    # TODO(PY38): while tasks := asyncio.all_tasks() - exclude:
    while True:
        pending = asyncio.all_tasks() - ignored
        if not pending:
            # asyncio.wait() rejects an empty set, so stop before calling it.
            return
        await asyncio.wait(pending)


class TCPSite(BaseSite):
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
Expand Down Expand Up @@ -241,7 +257,7 @@ async def start(self) -> None:


class BaseRunner(ABC):
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites")
__slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites")

def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
self._handle_signals = handle_signals
Expand Down Expand Up @@ -281,6 +297,11 @@ async def setup(self) -> None:
pass

self._server = await self._make_server()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in self._make_server()),
# so we just record all running tasks here and exclude them later.
self.starting_tasks = asyncio.all_tasks()

@abstractmethod
async def shutdown(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions docs/web_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -923,8 +923,14 @@ Graceful shutdown
Stopping *aiohttp web server* by just closing all connections is not
always satisfactory.

The problem is: if application supports :term:`websocket`\s or *data
streaming* it most likely has open connections at server
The first thing aiohttp will do is to stop listening on the sockets,
so new connections will be rejected. It will then wait a few
seconds to allow any pending tasks to complete before continuing
with application shutdown. The timeout can be adjusted with
``shutdown_timeout`` in :func:`run_app`.

Another problem is that if the application supports :term:`websockets <websocket>` or
*data streaming*, it most likely has open connections at server
shutdown time.

The *library* has no knowledge how to close them gracefully but
Expand Down
34 changes: 21 additions & 13 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2777,9 +2777,10 @@ application on specific TCP or Unix socket, e.g.::

:param int port: PORT to listen on, ``8080`` if ``None`` (default).

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

:param ssl_context: a :class:`ssl.SSLContext` instance for serving
SSL/TLS secure server, ``None`` for plain HTTP
Expand Down Expand Up @@ -2812,9 +2813,10 @@ application on specific TCP or Unix socket, e.g.::

:param str path: PATH to UNIX socket to listen.

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

:param ssl_context: a :class:`ssl.SSLContext` instance for serving
SSL/TLS secure server, ``None`` for plain HTTP
Expand All @@ -2834,9 +2836,10 @@ application on specific TCP or Unix socket, e.g.::

:param str path: PATH of named pipe to listen.

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

.. class:: SockSite(runner, sock, *, \
shutdown_timeout=60.0, ssl_context=None, \
Expand All @@ -2848,9 +2851,10 @@ application on specific TCP or Unix socket, e.g.::

:param sock: A :ref:`socket instance <socket-objects>` to listen to.

:param float shutdown_timeout: a timeout for closing opened
connections on :meth:`BaseSite.stop`
call.
:param float shutdown_timeout: a timeout used for both waiting on pending
tasks before application shutdown and for
closing opened connections on
:meth:`BaseSite.stop` call.

:param ssl_context: a :class:`ssl.SSLContext` instance for serving
SSL/TLS secure server, ``None`` for plain HTTP
Expand Down Expand Up @@ -2944,9 +2948,13 @@ Utilities
shutdown before disconnecting all
open client sockets hard way.

This is used as a delay to wait for
pending tasks to complete and then
again to close any pending connections.

A system with properly
:ref:`aiohttp-web-graceful-shutdown`
implemented never waits for this
implemented never waits for the second
timeout but closes a server in a few
milliseconds.

Expand Down
198 changes: 197 additions & 1 deletion tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import ssl
import subprocess
import sys
import time
from typing import Callable, NoReturn
from unittest import mock
from uuid import uuid4

import pytest
from conftest import IS_UNIX, needs_unix

from aiohttp import web
from aiohttp import ClientConnectorError, ClientSession, web
from aiohttp.test_utils import make_mocked_coro
from aiohttp.web_runner import BaseRunner

Expand Down Expand Up @@ -910,3 +912,197 @@ async def init():

web.run_app(init(), print=stopper(patched_loop), loop=patched_loop)
assert count == 3


class TestShutdown:
    """End-to-end tests for the graceful shutdown period of ``web.run_app``.

    Each test starts a real server with ``web.run_app``, drives it with a
    client running inside the same event loop, and stops it by requesting
    ``/stop``.
    """

    def raiser(self) -> NoReturn:
        """Raise ``KeyboardInterrupt``; scheduled on the loop by :meth:`stop`."""
        raise KeyboardInterrupt

    async def stop(self, request: web.Request) -> web.Response:
        """Handle ``/stop``: schedule the interrupt and return an empty response."""
        # call_soon() defers the exception until after this handler has returned.
        asyncio.get_running_loop().call_soon(self.raiser)
        return web.Response()

    def run_app(self, port: int, timeout: int, task, extra_test=None) -> asyncio.Task:
        """Run an application until ``/stop`` is requested.

        :param port: port to serve on.
        :param timeout: value passed to ``web.run_app`` as ``shutdown_timeout``.
        :param task: coroutine function; the ``/`` handler spawns it as a task.
        :param extra_test: optional coroutine function, awaited with the client
                           session after ``/stop`` has been requested.
        :return: the task created by the ``/`` handler.
        """

        async def test() -> None:
            await asyncio.sleep(1)
            async with ClientSession() as sess:
                async with sess.get(f"http://localhost:{port}/"):
                    pass
                async with sess.get(f"http://localhost:{port}/stop"):
                    pass

                if extra_test:
                    await extra_test(sess)

        async def run_test(app: web.Application) -> None:
            # Cleanup context: start the client on startup, await it on cleanup.
            nonlocal test_task
            test_task = asyncio.create_task(test())
            yield
            await test_task

        async def handler(request: web.Request) -> web.Response:
            nonlocal t
            t = asyncio.create_task(task())
            return web.Response(text="FOO")

        t = test_task = None
        app = web.Application()
        app.cleanup_ctx.append(run_test)
        app.router.add_get("/", handler)
        app.router.add_get("/stop", self.stop)

        web.run_app(app, port=port, shutdown_timeout=timeout)
        assert test_task.exception() is None
        return t

    def test_shutdown_wait_for_task(
        self, aiohttp_unused_port: Callable[[], int]
    ) -> None:
        """A 2 second task completes when the shutdown timeout is 3 seconds."""
        port = aiohttp_unused_port()
        finished = False

        async def task():
            nonlocal finished
            await asyncio.sleep(2)
            finished = True

        t = self.run_app(port, 3, task)

        assert finished is True
        assert t.done()
        assert not t.cancelled()

    def test_shutdown_timeout_task(
        self, aiohttp_unused_port: Callable[[], int]
    ) -> None:
        """A 2 second task is cancelled when the shutdown timeout is 1 second."""
        port = aiohttp_unused_port()
        finished = False

        async def task():
            nonlocal finished
            await asyncio.sleep(2)
            finished = True

        t = self.run_app(port, 1, task)

        assert finished is False
        assert t.done()
        assert t.cancelled()

    def test_shutdown_wait_for_spawned_task(
        self, aiohttp_unused_port: Callable[[], int]
    ) -> None:
        """A task spawned by a pending task is also waited for."""
        port = aiohttp_unused_port()
        finished = False
        finished_sub = False
        sub_t = None

        async def sub_task():
            nonlocal finished_sub
            await asyncio.sleep(1.5)
            finished_sub = True

        async def task():
            nonlocal finished, sub_t
            await asyncio.sleep(0.5)
            sub_t = asyncio.create_task(sub_task())
            finished = True

        t = self.run_app(port, 3, task)

        assert finished is True
        assert t.done()
        assert not t.cancelled()
        assert finished_sub is True
        assert sub_t.done()
        assert not sub_t.cancelled()

    def test_shutdown_timeout_not_reached(
        self, aiohttp_unused_port: Callable[[], int]
    ) -> None:
        """Shutdown returns once tasks finish instead of waiting out the timeout."""
        port = aiohttp_unused_port()
        finished = False

        async def task():
            nonlocal finished
            await asyncio.sleep(1)
            finished = True

        # Use a monotonic clock: wall-clock time may jump while the test runs.
        start_time = time.monotonic()
        t = self.run_app(port, 15, task)

        assert finished is True
        assert t.done()
        # Verify run_app has not waited for timeout.
        assert time.monotonic() - start_time < 10

    def test_shutdown_new_conn_rejected(
        self, aiohttp_unused_port: Callable[[], int]
    ) -> None:
        """New connections are refused while shutdown waits for a pending task."""
        port = aiohttp_unused_port()
        finished = False

        async def task() -> None:
            nonlocal finished
            await asyncio.sleep(9)
            finished = True

        async def test(sess: ClientSession) -> None:
            # Ensure we are in the middle of shutdown (waiting for task()).
            await asyncio.sleep(1)
            with pytest.raises(ClientConnectorError):
                # Use a new session to try and open a new connection.
                async with ClientSession() as new_sess:
                    async with new_sess.get(f"http://localhost:{port}/"):
                        pass
            assert finished is False

        t = self.run_app(port, 10, task, test)

        assert finished is True
        assert t.done()

    def test_shutdown_pending_handler_responds(
        self, aiohttp_unused_port: Callable[[], int]
    ) -> None:
        """A handler in progress when shutdown starts still sends its response."""
        port = aiohttp_unused_port()
        finished = False

        async def test() -> None:
            async def test_resp(sess):
                async with sess.get(f"http://localhost:{port}/") as resp:
                    assert await resp.text() == "FOO"

            await asyncio.sleep(1)
            async with ClientSession() as sess:
                resp_task = asyncio.create_task(test_resp(sess))
                await asyncio.sleep(1)
                # Handler is in-progress while we trigger server shutdown.
                async with sess.get(f"http://localhost:{port}/stop"):
                    pass

                assert finished is False
                # Handler should still complete and produce a response.
                await resp_task

        async def run_test(app: web.Application) -> None:
            # Cleanup context: start the client on startup, await it on cleanup.
            nonlocal t
            t = asyncio.create_task(test())
            yield
            await t

        async def handler(request: web.Request) -> web.Response:
            nonlocal finished
            await asyncio.sleep(3)
            finished = True
            return web.Response(text="FOO")

        t = None
        app = web.Application()
        app.cleanup_ctx.append(run_test)
        app.router.add_get("/", handler)
        app.router.add_get("/stop", self.stop)

        web.run_app(app, port=port, shutdown_timeout=5)
        assert t.exception() is None
        assert finished is True

0 comments on commit dc0a8d4

Please sign in to comment.