Skip to content

Commit

Permalink
Handle SIGTERM received by agent gracefully (PrefectHQ#8691)
Browse files Browse the repository at this point in the history
Signed-off-by: ddelange <[email protected]>
Co-authored-by: Zanie Adkins <[email protected]>
Co-authored-by: Alexander Streed <[email protected]>
  • Loading branch information
3 people authored and Åsmund Østvold committed May 11, 2023
1 parent 461faaa commit b7e5e42
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/prefect/cli/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Command line interface for working with agent services
"""
import os
from functools import partial
from typing import List
from uuid import UUID
Expand All @@ -20,6 +21,7 @@
PREFECT_AGENT_QUERY_INTERVAL,
PREFECT_API_URL,
)
from prefect.utilities.processutils import setup_signal_handlers_agent
from prefect.utilities.services import critical_service_loop

agent_app = PrefectTyper(
Expand Down Expand Up @@ -161,6 +163,11 @@ async def start(
f"Starting v{prefect.__version__} agent with ephemeral API..."
)

agent_process_id = os.getpid()
setup_signal_handlers_agent(
agent_process_id, "the Prefect agent", app.console.print
)

async with PrefectAgent(
work_queues=work_queues,
work_queue_prefix=work_queue_prefix,
Expand Down
31 changes: 29 additions & 2 deletions src/prefect/utilities/processutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,20 @@ def forward_signal_handler(
"""Forward subsequent signum events (e.g. interrupts) to respective signums."""
current_signal, future_signals = signums[0], signums[1:]

# avoid RecursionError when setting up a direct signal forward to the same signal for the main pid
avoid_infinite_recursion = signum == current_signal and pid == os.getpid()
if avoid_infinite_recursion:
# store the vanilla handler so it can be temporarily restored below
original_handler = signal.getsignal(current_signal)

def handler(*args):
print_fn(
f"Received {getattr(signum, 'name', signum)}. "
f"Sending {getattr(current_signal, 'name', current_signal)} to"
f" {process_name} (PID {pid})..."
)
if avoid_infinite_recursion:
signal.signal(current_signal, original_handler)
os.kill(pid, current_signal)
if future_signals:
forward_signal_handler(
Expand All @@ -345,12 +353,31 @@ def setup_signal_handlers_server(pid: int, process_name: str, print_fn: Callable
setup_handler = partial(
forward_signal_handler, pid, process_name=process_name, print_fn=print_fn
)
# on Windows, use CTRL_BREAK_EVENT as SIGTERM is useless:
# https://bugs.python.org/issue26350
# when server receives a signal, it needs to be propagated to the uvicorn subprocess
if sys.platform == "win32":
# on Windows, use CTRL_BREAK_EVENT as SIGTERM is useless:
# https://bugs.python.org/issue26350
setup_handler(signal.SIGINT, signal.CTRL_BREAK_EVENT)
else:
# first interrupt: SIGTERM, second interrupt: SIGKILL
setup_handler(signal.SIGINT, signal.SIGTERM, signal.SIGKILL)
# forward first SIGTERM directly, send SIGKILL on subsequent SIGTERM
setup_handler(signal.SIGTERM, signal.SIGTERM, signal.SIGKILL)


def setup_signal_handlers_agent(pid: int, process_name: str, print_fn: Callable):
    """
    Install the interrupt handlers used by the agent process.

    Args:
        pid: ID of the process the signals are sent to.
        process_name: Human-readable name used in the printed messages.
        print_fn: Callable used to report each received/forwarded signal.
    """
    register = partial(
        forward_signal_handler, pid, process_name=process_name, print_fn=print_fn
    )

    # When the agent receives SIGINT, it stops dequeueing new FlowRuns and runs
    # until the subprocesses finish. The signal is not forwarded to subprocesses,
    # so they can continue to run and hopefully still complete.
    if sys.platform == "win32":
        # SIGTERM is useless on Windows, so CTRL_BREAK_EVENT is used instead:
        # https://bugs.python.org/issue26350
        register(signal.SIGINT, signal.CTRL_BREAK_EVENT)
        return

    # First SIGINT is passed on as SIGINT; a subsequent interrupt sends SIGKILL.
    register(signal.SIGINT, signal.SIGINT, signal.SIGKILL)
    # First SIGTERM is translated to SIGINT; a subsequent SIGTERM sends SIGKILL.
    register(signal.SIGTERM, signal.SIGINT, signal.SIGKILL)
170 changes: 170 additions & 0 deletions tests/cli/test_start_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import os
import signal
import sys
import tempfile

import anyio
import pytest

from prefect.settings import get_current_settings
from prefect.utilities.processutils import open_process

# Seconds slept between two checks of the agent's captured output during startup
POLL_INTERVAL = 0.5
# Startup budget: the fixture polls STARTUP_TIMEOUT / POLL_INTERVAL times at most
STARTUP_TIMEOUT = 20
# Time limit (passed to anyio.fail_after) for each wait on the process exiting
SHUTDOWN_TIMEOUT = 5


async def safe_shutdown(process):
    """
    Wait for ``process`` to exit, allowing a second attempt if the first times out.

    Each attempt is limited to ``SHUTDOWN_TIMEOUT``; a timeout on the second
    attempt propagates as ``TimeoutError``.
    """

    async def _wait_with_deadline():
        with anyio.fail_after(SHUTDOWN_TIMEOUT):
            await process.wait()

    try:
        await _wait_with_deadline()
    except TimeoutError:
        # try twice in case process.wait() hangs
        await _wait_with_deadline()


@pytest.fixture(scope="function")
async def agent_process(use_hosted_api_server):
    """
    Runs ``prefect agent start --match=nonexist`` in a subprocess.

    The subprocess' stdout and stderr are both redirected into a temporary file,
    which is attached to the process object as ``process.out`` so that tests can
    make assertions on what the agent printed.

    Yields:
        The anyio.Process.
    """
    # Capture output for test assertions. Using the file as a context manager
    # guarantees it is closed even when the agent fails to start up.
    with tempfile.TemporaryFile() as out:
        # Will connect to the same database as normal test clients
        async with open_process(
            command=[
                "prefect",
                "agent",
                "start",
                "--match=nonexist",
            ],
            stdout=out,
            stderr=out,
            env={**os.environ, **get_current_settings().to_environment_variables()},
        ) as process:
            process.out = out

            # Poll until the agent has written more than 400 bytes of output,
            # which is taken as the sign that it has started up.
            for _ in range(int(STARTUP_TIMEOUT / POLL_INTERVAL)):
                await anyio.sleep(POLL_INTERVAL)
                if out.tell() > 400:
                    # extra grace period after the output threshold is reached
                    await anyio.sleep(2)
                    break

            assert out.tell() > 400, "The agent did not start up in time"
            assert process.returncode is None, "The agent failed to start up"

            # Yield to the consuming tests
            yield process

            # Then shutdown the process
            try:
                process.terminate()
            except ProcessLookupError:
                # the test already stopped the process
                pass


class TestAgentSignalForwarding:
    """
    Checks the messages printed by a running agent process when it is sent
    SIGINT / SIGTERM (or SIGINT on Windows).

    NOTE: some test names mention SIGTERM although the assertions check that
    SIGINT is what gets sent to the agent; the names are kept unchanged so that
    existing test IDs stay stable.
    """

    @staticmethod
    def _read_output(process) -> str:
        """Return the decoded content of the ``out`` file attached to ``process``."""
        process.out.seek(0)
        return process.out.read().decode()

    @pytest.mark.skipif(
        sys.platform == "win32",
        reason="SIGTERM is only used in non-Windows environments",
    )
    async def test_sigint_sends_sigterm(self, agent_process):
        """A single SIGINT is reported as a SIGINT and the agent stops."""
        agent_process.send_signal(signal.SIGINT)
        await safe_shutdown(agent_process)
        out = self._read_output(agent_process)

        assert "Sending SIGINT" in out, (
            "When sending a SIGINT, the main process should receive a SIGINT."
            f" Output:\n{out}"
        )
        assert "Agent stopped!" in out, (
            "When sending a SIGINT, the main process should shutdown gracefully."
            f" Output:\n{out}"
        )

    @pytest.mark.skipif(
        sys.platform == "win32",
        reason="SIGTERM is only used in non-Windows environments",
    )
    async def test_sigterm_sends_sigterm_directly(self, agent_process):
        """A single SIGTERM is reported as a SIGINT and the agent stops."""
        agent_process.send_signal(signal.SIGTERM)
        await safe_shutdown(agent_process)
        out = self._read_output(agent_process)

        assert "Sending SIGINT" in out, (
            "When sending a SIGTERM, the main process should receive a SIGINT."
            f" Output:\n{out}"
        )
        assert "Agent stopped!" in out, (
            "When sending a SIGTERM, the main process should shutdown gracefully."
            f" Output:\n{out}"
        )

    @pytest.mark.skipif(
        sys.platform == "win32",
        reason="SIGTERM is only used in non-Windows environments",
    )
    async def test_sigint_sends_sigterm_then_sigkill(self, agent_process):
        """Two SIGINTs in quick succession end in a SIGKILL or an already-closing agent."""
        agent_process.send_signal(signal.SIGINT)
        await anyio.sleep(0.01)  # some time needed for the recursive signal handler
        agent_process.send_signal(signal.SIGINT)
        await safe_shutdown(agent_process)
        out = self._read_output(agent_process)

        assert (
            # either the main PID is still waiting for shutdown, so forwards the SIGKILL
            "Sending SIGKILL" in out
            # or SIGKILL came too late, and the main PID is already closing
            or "KeyboardInterrupt" in out
            or "Agent stopped!" in out
            or "Aborted." in out
        ), (
            "When sending two SIGINT shortly after each other, the main process should"
            f" first receive a SIGINT and then a SIGKILL. Output:\n{out}"
        )

    @pytest.mark.skipif(
        sys.platform == "win32",
        reason="SIGTERM is only used in non-Windows environments",
    )
    async def test_sigterm_sends_sigterm_then_sigkill(self, agent_process):
        """Two SIGTERMs in quick succession end in a SIGKILL or an already-closing agent."""
        agent_process.send_signal(signal.SIGTERM)
        await anyio.sleep(0.01)  # some time needed for the recursive signal handler
        agent_process.send_signal(signal.SIGTERM)
        await safe_shutdown(agent_process)
        out = self._read_output(agent_process)

        assert (
            # either the main PID is still waiting for shutdown, so forwards the SIGKILL
            "Sending SIGKILL" in out
            # or SIGKILL came too late, and the main PID is already closing
            or "KeyboardInterrupt" in out
            or "Agent stopped!" in out
            or "Aborted." in out
        ), (
            "When sending two SIGTERM shortly after each other, the main process should"
            f" first receive a SIGINT and then a SIGKILL. Output:\n{out}"
        )

    @pytest.mark.skipif(
        sys.platform != "win32",
        reason="CTRL_BREAK_EVENT is only defined in Windows",
    )
    async def test_sends_ctrl_break_win32(self, agent_process):
        """On Windows, a SIGINT is reported as a CTRL_BREAK_EVENT being sent."""
        agent_process.send_signal(signal.SIGINT)
        await safe_shutdown(agent_process)
        out = self._read_output(agent_process)

        # The agent has no uvicorn subprocess; the event targets the agent itself.
        assert "Sending CTRL_BREAK_EVENT" in out, (
            "When sending a SIGINT, the main process should send a CTRL_BREAK_EVENT to"
            f" the agent process. Output:\n{out}"
        )
3 changes: 2 additions & 1 deletion tests/cli/test_start_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from prefect.testing.fixtures import is_port_in_use
from prefect.utilities.processutils import open_process

POLL_INTERVAL = 0.5
STARTUP_TIMEOUT = 20
SHUTDOWN_TIMEOUT = 20

Expand Down Expand Up @@ -70,7 +71,7 @@ async def server_process():
if response.status_code == 200:
await anyio.sleep(0.5) # extra sleep for less flakiness
break
await anyio.sleep(0.1)
await anyio.sleep(POLL_INTERVAL)
if response:
response.raise_for_status()
if not response:
Expand Down

0 comments on commit b7e5e42

Please sign in to comment.