forked from PrefectHQ/prefect
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle SIGTERM received by agent gracefully (PrefectHQ#8691)
Signed-off-by: ddelange <[email protected]> Co-authored-by: Zanie Adkins <[email protected]> Co-authored-by: Alexander Streed <[email protected]>
- Loading branch information
1 parent
461faaa
commit b7e5e42
Showing
4 changed files
with
208 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import os | ||
import signal | ||
import sys | ||
import tempfile | ||
|
||
import anyio | ||
import pytest | ||
|
||
from prefect.settings import get_current_settings | ||
from prefect.utilities.processutils import open_process | ||
|
||
POLL_INTERVAL = 0.5 | ||
STARTUP_TIMEOUT = 20 | ||
SHUTDOWN_TIMEOUT = 5 | ||
|
||
|
||
async def safe_shutdown(process): | ||
try: | ||
with anyio.fail_after(SHUTDOWN_TIMEOUT): | ||
await process.wait() | ||
except TimeoutError: | ||
# try twice in case process.wait() hangs | ||
with anyio.fail_after(SHUTDOWN_TIMEOUT): | ||
await process.wait() | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
async def agent_process(use_hosted_api_server): | ||
""" | ||
Runs an agent listening to all queues. | ||
Yields: | ||
The anyio.Process. | ||
""" | ||
out = tempfile.TemporaryFile() # capture output for test assertions | ||
|
||
# Will connect to the same database as normal test clients | ||
async with open_process( | ||
command=[ | ||
"prefect", | ||
"agent", | ||
"start", | ||
"--match=nonexist", | ||
], | ||
stdout=out, | ||
stderr=out, | ||
env={**os.environ, **get_current_settings().to_environment_variables()}, | ||
) as process: | ||
process.out = out | ||
|
||
for _ in range(int(STARTUP_TIMEOUT / POLL_INTERVAL)): | ||
await anyio.sleep(POLL_INTERVAL) | ||
if out.tell() > 400: | ||
await anyio.sleep(2) | ||
break | ||
|
||
assert out.tell() > 400, "The agent did not start up in time" | ||
assert process.returncode is None, "The agent failed to start up" | ||
|
||
# Yield to the consuming tests | ||
yield process | ||
|
||
# Then shutdown the process | ||
try: | ||
process.terminate() | ||
except ProcessLookupError: | ||
pass | ||
out.close() | ||
|
||
|
||
class TestAgentSignalForwarding: | ||
@pytest.mark.skipif( | ||
sys.platform == "win32", | ||
reason="SIGTERM is only used in non-Windows environments", | ||
) | ||
async def test_sigint_sends_sigterm(self, agent_process): | ||
agent_process.send_signal(signal.SIGINT) | ||
await safe_shutdown(agent_process) | ||
agent_process.out.seek(0) | ||
out = agent_process.out.read().decode() | ||
|
||
assert "Sending SIGINT" in out, ( | ||
"When sending a SIGINT, the main process should receive a SIGINT." | ||
f" Output:\n{out}" | ||
) | ||
assert "Agent stopped!" in out, ( | ||
"When sending a SIGINT, the main process should shutdown gracefully." | ||
f" Output:\n{out}" | ||
) | ||
|
||
@pytest.mark.skipif( | ||
sys.platform == "win32", | ||
reason="SIGTERM is only used in non-Windows environments", | ||
) | ||
async def test_sigterm_sends_sigterm_directly(self, agent_process): | ||
agent_process.send_signal(signal.SIGTERM) | ||
await safe_shutdown(agent_process) | ||
agent_process.out.seek(0) | ||
out = agent_process.out.read().decode() | ||
|
||
assert "Sending SIGINT" in out, ( | ||
"When sending a SIGTERM, the main process should receive a SIGINT." | ||
f" Output:\n{out}" | ||
) | ||
assert "Agent stopped!" in out, ( | ||
"When sending a SIGTERM, the main process should shutdown gracefully." | ||
f" Output:\n{out}" | ||
) | ||
|
||
@pytest.mark.skipif( | ||
sys.platform == "win32", | ||
reason="SIGTERM is only used in non-Windows environments", | ||
) | ||
async def test_sigint_sends_sigterm_then_sigkill(self, agent_process): | ||
agent_process.send_signal(signal.SIGINT) | ||
await anyio.sleep(0.01) # some time needed for the recursive signal handler | ||
agent_process.send_signal(signal.SIGINT) | ||
await safe_shutdown(agent_process) | ||
agent_process.out.seek(0) | ||
out = agent_process.out.read().decode() | ||
|
||
assert ( | ||
# either the main PID is still waiting for shutdown, so forwards the SIGKILL | ||
"Sending SIGKILL" in out | ||
# or SIGKILL came too late, and the main PID is already closing | ||
or "KeyboardInterrupt" in out | ||
or "Agent stopped!" in out | ||
or "Aborted." in out | ||
), ( | ||
"When sending two SIGINT shortly after each other, the main process should" | ||
f" first receive a SIGINT and then a SIGKILL. Output:\n{out}" | ||
) | ||
|
||
@pytest.mark.skipif( | ||
sys.platform == "win32", | ||
reason="SIGTERM is only used in non-Windows environments", | ||
) | ||
async def test_sigterm_sends_sigterm_then_sigkill(self, agent_process): | ||
agent_process.send_signal(signal.SIGTERM) | ||
await anyio.sleep(0.01) # some time needed for the recursive signal handler | ||
agent_process.send_signal(signal.SIGTERM) | ||
await safe_shutdown(agent_process) | ||
agent_process.out.seek(0) | ||
out = agent_process.out.read().decode() | ||
|
||
assert ( | ||
# either the main PID is still waiting for shutdown, so forwards the SIGKILL | ||
"Sending SIGKILL" in out | ||
# or SIGKILL came too late, and the main PID is already closing | ||
or "KeyboardInterrupt" in out | ||
or "Agent stopped!" in out | ||
or "Aborted." in out | ||
), ( | ||
"When sending two SIGTERM shortly after each other, the main process should" | ||
f" first receive a SIGINT and then a SIGKILL. Output:\n{out}" | ||
) | ||
|
||
@pytest.mark.skipif( | ||
sys.platform != "win32", | ||
reason="CTRL_BREAK_EVENT is only defined in Windows", | ||
) | ||
async def test_sends_ctrl_break_win32(self, agent_process): | ||
agent_process.send_signal(signal.SIGINT) | ||
await safe_shutdown(agent_process) | ||
agent_process.out.seek(0) | ||
out = agent_process.out.read().decode() | ||
|
||
assert "Sending CTRL_BREAK_EVENT" in out, ( | ||
"When sending a SIGINT, the main process should send a CTRL_BREAK_EVENT to" | ||
f" the uvicorn subprocess. Output:\n{out}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters