Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cooperative signal handling #1600

Merged
merged 14 commits into from
Mar 19, 2024
67 changes: 67 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

import asyncio
import contextlib
import signal
import sys
from typing import Callable, ContextManager, Generator

import pytest

from uvicorn.config import Config
from uvicorn.server import Server


# asyncio does NOT allow raising in signal handlers, so to detect
# raised signals raised a mutable `witness` receives the signal
@contextlib.contextmanager
def capture_signal_sync(sig: signal.Signals) -> Generator[list[int], None, None]:
"""Replace `sig` handling with a normal exception via `signal"""
witness: list[int] = []
original_handler = signal.signal(sig, lambda signum, frame: witness.append(signum))
yield witness
signal.signal(sig, original_handler)


@contextlib.contextmanager
def capture_signal_async(sig: signal.Signals) -> Generator[list[int], None, None]: # pragma: py-win32
"""Replace `sig` handling with a normal exception via `asyncio"""
witness: list[int] = []
original_handler = signal.getsignal(sig)
asyncio.get_running_loop().add_signal_handler(sig, witness.append, sig)
yield witness
signal.signal(sig, original_handler)


async def dummy_app(scope, receive, send): # pragma: py-win32
pass


if sys.platform == "win32":
signals = [signal.SIGBREAK]
signal_captures = [capture_signal_sync]
else:
signals = [signal.SIGTERM, signal.SIGINT]
signal_captures = [capture_signal_sync, capture_signal_async]


@pytest.mark.anyio
@pytest.mark.parametrize("exception_signal", signals)
@pytest.mark.parametrize("capture_signal", signal_captures)
async def test_server_interrupt(
exception_signal: signal.Signals, capture_signal: Callable[[signal.Signals], ContextManager[None]]
): # pragma: py-win32
"""Test interrupting a Server that is run explicitly inside asyncio"""

async def interrupt_running(srv: Server):
while not srv.started:
await asyncio.sleep(0.01)
signal.raise_signal(exception_signal)

server = Server(Config(app=dummy_app, loop="asyncio"))
asyncio.create_task(interrupt_running(server))
with capture_signal(exception_signal) as witness:
await server.serve()
assert witness
# set by the server's graceful exit handler
assert server.should_exit
39 changes: 25 additions & 14 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import os
import platform
Expand All @@ -11,7 +12,7 @@
import time
from email.utils import formatdate
from types import FrameType
from typing import TYPE_CHECKING, Sequence, Union
from typing import TYPE_CHECKING, Generator, Sequence, Union

import click

Expand Down Expand Up @@ -57,11 +58,17 @@ def __init__(self, config: Config) -> None:
self.force_exit = False
self.last_notified = 0.0

self._captured_signals: list[int] = []

def run(self, sockets: list[socket.socket] | None = None) -> None:
self.config.setup_event_loop()
return asyncio.run(self.serve(sockets=sockets))

async def serve(self, sockets: list[socket.socket] | None = None) -> None:
with self.capture_signals():
await self._serve(sockets)

async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
process_id = os.getpid()

config = self.config
Expand All @@ -70,8 +77,6 @@ async def serve(self, sockets: list[socket.socket] | None = None) -> None:

self.lifespan = config.lifespan_class(config)

self.install_signal_handlers()

message = "Started server process [%d]"
color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
logger.info(message, process_id, extra={"color_message": color_message})
Expand Down Expand Up @@ -302,22 +307,28 @@ async def _wait_tasks_to_complete(self) -> None:
for server in self.servers:
await server.wait_closed()

def install_signal_handlers(self) -> None:
@contextlib.contextmanager
def capture_signals(self) -> Generator[None, None, None]:
# Signals can only be listened to from the main thread.
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
yield
return

loop = asyncio.get_event_loop()

# always use signal.signal, even if loop.add_signal_handler is available
# this allows to restore previous signal handlers later on
original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
Comment on lines +316 to +318
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we use loop.add_signal_handler before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't found a technical reason for it. It seems to have been added in #141 for no specific and then copied over the years. I assume it was used because the asyncio docs promote it as a Linux feature that is better in unspecified ways.

On the technical side the only advantage of loop.add_signal_handler is that it allows synchronously triggering async code (e.g. event.set()), and the handlers don't use that. Since the code has to be compatible with Windows, handlers cannot rely on being invoked by the loop anyways.

try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError: # pragma: no cover
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
yield
finally:
for sig, handler in original_handlers.items():
signal.signal(sig, handler)
# If we did gracefully shut down due to a signal, try to
# trigger the expected behaviour now; multiple signals would be
# done LIFO, see https://stackoverflow.com/questions/48434964
for captured_signal in reversed(self._captured_signals):
signal.raise_signal(captured_signal)

def handle_exit(self, sig: int, frame: FrameType | None) -> None:
self._captured_signals.append(sig)
if self.should_exit and sig == signal.SIGINT:
self.force_exit = True
else:
Expand Down