From 7c85e5f55b8a92d70a26506a9f1628dc8d7547b0 Mon Sep 17 00:00:00 2001 From: DasSkelett Date: Sat, 6 Jan 2024 19:02:54 +0000 Subject: [PATCH] Make worker cleanup threads more robust, handle peers without handshake time --- README.md | 8 ++++---- wgkex/worker/app.py | 14 +++++++++----- wgkex/worker/app_test.py | 30 +++++++++++++++++++++++------- wgkex/worker/mqtt_test.py | 21 ++++++--------------- wgkex/worker/netlink.py | 9 ++++----- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 9f673f8..911aeb8 100644 --- a/README.md +++ b/README.md @@ -218,10 +218,10 @@ sudo ip link set vx-welt up ### MQTT topics -Publishing keys broker->worker: `wireguard/{domain}/{worker}` -Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers` -Publishing worker status: `wireguard-worker/{worker}/status` -Publishing worker data: `wireguard-worker/{worker}/{domain}/data` +- Publishing keys broker->worker: `wireguard/{domain}/{worker}` +- Publishing metrics worker->broker: `wireguard-metrics/{domain}/{worker}/connected_peers` +- Publishing worker status: `wireguard-worker/{worker}/status` +- Publishing worker data: `wireguard-worker/{worker}/{domain}/data` ## Contact diff --git a/wgkex/worker/app.py b/wgkex/worker/app.py index 9a07d97..432955c 100644 --- a/wgkex/worker/app.py +++ b/wgkex/worker/app.py @@ -39,9 +39,14 @@ class InvalidDomain(Error): def flush_workers(domain: Text) -> None: """Calls peer flush every _CLEANUP_TIME interval.""" while True: - time.sleep(_CLEANUP_TIME) - logger.info(f"Running cleanup task for {domain}") - logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) + try: + time.sleep(_CLEANUP_TIME) + logger.info(f"Running cleanup task for {domain}") + logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain)) + except Exception as e: + # Don't crash the thread when an exception is encountered + logger.error(f"Exception during cleanup task for {domain}:") + logger.error(e) def clean_up_worker() -> None: @@ -100,8 +105,7 @@ def check_all_domains_unique(domains, prefixes): stripped_domain = domain.split(prefix)[1] if stripped_domain in unique_domains: logger.error( - "We have a non-unique domain here", - domain, + f"Domain {domain} is not unique after stripping the prefix" ) return False unique_domains.append(stripped_domain) diff --git a/wgkex/worker/app_test.py b/wgkex/worker/app_test.py index 0cf525f..04cc6fb 100644 --- a/wgkex/worker/app_test.py +++ b/wgkex/worker/app_test.py @@ -1,4 +1,6 @@ """Unit tests for app.py""" +import threading +from time import sleep import unittest import mock @@ -88,14 +90,28 @@ def test_main_fails_bad_domain(self, connect_mock, config_mock): app.main() connect_mock.assert_not_called() - @mock.patch("time.sleep", side_effect=InterruptedError) + @mock.patch.object(app, "_CLEANUP_TIME", 0) @mock.patch.object(app, "wg_flush_stale_peers") - def test_flush_workers(self, flush_mock, sleep_mock): - """Ensure we fail when domains are badly formatted.""" - flush_mock.return_value = "" - # Infinite loop in flush_workers has no exit value, so test will generate one, and test for that. - with self.assertRaises(InterruptedError): - app.flush_workers("test_domain") + def test_flush_workers_doesnt_throw(self, wg_flush_mock): + """Ensure the flush_workers thread doesn't throw and exit if it encounters an exception.""" + wg_flush_mock.side_effect = AttributeError( + "'NoneType' object has no attribute 'get'" + ) + + thread = threading.Thread( + target=app.flush_workers, args=("dummy_domain",), daemon=True + ) + thread.start() + + i = 0 + while i < 20 and not wg_flush_mock.called: + i += 1 + sleep(0.1) + + wg_flush_mock.assert_called() + # Assert that the thread hasn't crashed and is still running + self.assertTrue(thread.is_alive()) + # If Python would allow it without writing custom signalling, this would be the place to stop the thread again if __name__ == "__main__": diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index b17d1d6..8bd6672 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -91,29 +91,20 @@ def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): self.assertFalse(thread.is_alive()) - -""" @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") - def test_on_message_wireguard_success(self, config_mock, link_mock): + def test_on_message_wireguard_success(self, config_mock): # Tests on_message for success. config_mock.return_value = _get_config_mock() - link_mock.return_value = dict(WireGuard="result") mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage") mqtt_msg.topic = "wireguard/_ffmuc_domain1/gateway" mqtt_msg.payload = b"PUB_KEY" mqtt.on_message_wireguard(None, None, mqtt_msg) - link_mock.assert_has_calls( - [ - mock.call( - msg_queue.WireGuardClient( - public_key="PUB_KEY", domain="domain1", remove=False - ) - ) - ], - any_order=True, - ) + self.assertTrue(mqtt.q.qsize() > 0) + item = mqtt.q.get_nowait() + self.assertEqual(item, ("domain1", "PUB_KEY")) + - @mock.patch.object(msg_queue, "link_handler") +""" @mock.patch.object(msg_queue, "link_handler") @mock.patch.object(mqtt, "get_config") def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock): # Tests on_message for failure to parse domain. diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index a1b5411..366d430 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -72,13 +72,12 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]: stale_clients = [ stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain) ] - logger.debug("Found stale clients: %s", stale_clients) - logger.info("Searching for stale WireGuard clients.") + logger.debug("Found %s stale clients: %s", len(stale_clients), stale_clients) stale_wireguard_clients = [ WireGuardClient(public_key=stale_client, domain=domain, remove=True) for stale_client in stale_clients ] - logger.debug("Found stable WireGuard clients: %s", stale_wireguard_clients) + logger.debug("Found stale WireGuard clients: %s", stale_wireguard_clients) logger.info("Processing clients.") link_handled = [ link_handler(stale_client) for stale_client in stale_wireguard_clients @@ -205,8 +204,8 @@ def find_stale_wireguard_clients(wg_interface: str) -> List: ret = [ peer.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") for peer in all_peers - if peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME").get("tv_sec", int()) - < three_hrs_in_secs + if (hshk_time := peer.get_attr("WGPEER_A_LAST_HANDSHAKE_TIME")) is not None + and hshk_time.get("tv_sec", int()) < three_hrs_in_secs ] return ret