diff --git a/waitress/server.py b/waitress/server.py index 307c377b..63cd3b31 100644 --- a/waitress/server.py +++ b/waitress/server.py @@ -14,6 +14,7 @@ import os import os.path +import signal import socket import time @@ -148,7 +149,11 @@ def print_listen(self, format_str): # pragma: nocover print(format_str.format(*l)) + def handle_sigterm(self, signum, frame): + raise SystemExit + def run(self): + signal.signal(signal.SIGTERM, self.handle_sigterm) try: self.asyncore.loop( timeout=self.adj.asyncore_loop_timeout, @@ -315,7 +320,11 @@ def handle_accept(self): addr = self.fix_addr(addr) self.channel_class(self, conn, addr, self.adj, map=self._map) + def handle_sigterm(self, signum, frame): + raise SystemExit + def run(self): + signal.signal(signal.SIGTERM, self.handle_sigterm) try: self.asyncore.loop( timeout=self.adj.asyncore_loop_timeout, diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py index 7cd63452..381ac25c 100644 --- a/waitress/tests/test_server.py +++ b/waitress/tests/test_server.py @@ -1,5 +1,9 @@ import errno +import os +import signal import socket +import threading +import time import unittest dummy_app = object() @@ -127,6 +131,25 @@ def test_run(self): inst.run() self.assertTrue(inst.task_dispatcher.was_shutdown) + def test_sigterm_exits_gracefully(self): + inst = self._makeOneWithMap(_start=False) + inst.asyncore = DummyBlockingAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + + pid = os.getpid() + + def trigger_signal(): + time.sleep(1) + os.kill(pid, signal.SIGTERM) + + thread = threading.Thread(target=trigger_signal) + thread.daemon = True + thread.start() + + inst.run() + + self.assertTrue(inst.task_dispatcher.was_shutdown) + def test_run_base_server(self): inst = self._makeOneWithMulti(_start=False) inst.asyncore = DummyAsyncore() @@ -465,6 +488,11 @@ class DummyAsyncore(object): def loop(self, timeout=30.0, use_poll=False, map=None, count=None): raise SystemExit +class DummyBlockingAsyncore(object): + def loop(self, timeout=30.0, use_poll=False, map=None, count=None): + while True: + pass + class DummyTrigger(object): def pull_trigger(self):