diff --git a/wgkex/broker/BUILD b/wgkex/broker/BUILD index 7ee73ef..316c91a 100644 --- a/wgkex/broker/BUILD +++ b/wgkex/broker/BUILD @@ -12,12 +12,3 @@ py_binary( "//wgkex/config:config", ], ) - -py_test( - name="app_test", - srcs=["app_test.py"], - deps=[ - ":app", - requirement("mock"), - ], -) diff --git a/wgkex/broker/app_test.py b/wgkex/broker/app_test.py deleted file mode 100644 index aa40b25..0000000 --- a/wgkex/broker/app_test.py +++ /dev/null @@ -1,90 +0,0 @@ -import unittest -import mock -import app -import sys -from wgkex.config.config_test import _VALID_CFG -from wgkex.config.config_test import _INVALID_CFG - - -class TestApp(unittest.TestCase): - - # TODO(ruairi): Add test for Flask. - # def setUp(self) -> None: - # mock_open = mock.mock_open(read_data=_VALID_CFG) - # with mock.patch("builtins.open", mock_open): - # app_cfg = app.app.test_client() - # app.main() - # self.app_cfg = app_cfg - - def test_app_load_success(self): - """Tests _fetch_app_config success.""" - mock_open = mock.mock_open(read_data=_VALID_CFG) - with mock.patch("builtins.open", mock_open): - cfg = app._fetch_app_config() - self.assertIsNotNone(cfg) - - @mock.patch.object(sys, "exit", autospec=True) - def test_app_load_fails_bad_config(self, exit_mock): - """Tests _fetch_app_config fails with bad configuration.""" - mock_open = mock.mock_open(read_data=_INVALID_CFG) - with mock.patch("builtins.open", mock_open): - with self.assertRaises(TypeError): - app._fetch_app_config() - exit_mock.assert_called_with(2) - - def test_is_valid_wg_pubkey_success(self): - """Tests is_valid_wg_pubkey success.""" - key = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAE=" - self.assertEqual(key, app.is_valid_wg_pubkey(key)) - - def test_is_valid_wg_pubkey_fails_bad_key(self): - """Tests is_valid_wg_pubkey fails on bad key.""" - key = "not_a_key" - with self.assertRaises(ValueError): - app.is_valid_wg_pubkey(key) - - def test_is_valid_domain_success(self): - """Tests is_valid_domain success.""" - domain = "a" - mock_open = mock.mock_open(read_data=_VALID_CFG) - with mock.patch("builtins.open", mock_open): - self.assertEqual(domain, app.is_valid_domain(domain)) - - def test_is_valid_domain_fails_domain_not_configured(self): - """Tests is_valid_domain fails on bad domain.""" - domain = "not_Configured" - mock_open = mock.mock_open(read_data=_VALID_CFG) - with mock.patch("builtins.open", mock_open): - with self.assertRaises(ValueError): - app.is_valid_domain(domain) - - def test_KeyExchange_success(self): - """Tests creating KeyExchange successfully.""" - key = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAE=" - msg = dict(public_key=key, domain="a") - expected = app.KeyExchange(public_key=key, domain="a") - mock_open = mock.mock_open(read_data=_VALID_CFG) - with mock.patch("builtins.open", mock_open): - self.assertEqual(expected, app.KeyExchange.from_dict(msg)) - - def test_KeyExchange_fails_bad_key(self): - """Tests creating KeyExchange fails due to bad key.""" - key = "asd" - msg = dict(public_key=key, domain="a") - mock_open = mock.mock_open(read_data=_VALID_CFG) - with mock.patch("builtins.open", mock_open): - with self.assertRaises(ValueError): - app.KeyExchange.from_dict(msg) - - def test_KeyExchange_fails_bad_domain(self): - """Tests creating KeyExchange fails due to bad key.""" - key = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAE=" - msg = dict(public_key=key, domain="unconfigured_domain") - mock_open = mock.mock_open(read_data=_VALID_CFG) - with mock.patch("builtins.open", mock_open): - with self.assertRaises(ValueError): - app.KeyExchange.from_dict(msg) - - -if __name__ == "__main__": - unittest.main() diff --git a/wgkex/config/config.py b/wgkex/config/config.py index 887a2b8..efe4d0f 100644 --- a/wgkex/config/config.py +++ b/wgkex/config/config.py @@ -105,13 +105,13 @@ def load_config() -> Dict[str, str]: try: config = yaml.safe_load(cfg_contents) except yaml.YAMLError as e: - logger.error("Failed to load YAML file: %s", e) + print("Failed to load YAML file: %s", e) sys.exit(1) try: _ = Config.from_dict(config) return config except (KeyError, TypeError) as e: - logger.error("Failed to lint file: %s", e) + print("Failed to lint file: %s", e) sys.exit(2) diff --git a/wgkex/worker/BUILD b/wgkex/worker/BUILD index 77be2e6..22e2424 100644 --- a/wgkex/worker/BUILD +++ b/wgkex/worker/BUILD @@ -11,7 +11,8 @@ py_library( requirement("paho-mqtt"), requirement("pyroute2"), "//wgkex/common:utils", - "//wgkex/common:logger" + "//wgkex/common:logger", + "//wgkex/config:config", ], ) @@ -33,8 +34,8 @@ py_library( requirement("paho-mqtt"), requirement("pyroute2"), "//wgkex/common:utils", - "//wgkex/config:config", "//wgkex/common:logger", + "//wgkex/config:config", ":netlink", ], ) diff --git a/wgkex/worker/netlink.py b/wgkex/worker/netlink.py index c34f630..03af944 100644 --- a/wgkex/worker/netlink.py +++ b/wgkex/worker/netlink.py @@ -70,7 +70,7 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]: stale_clients = [ stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain) ] - logger.debug("Found stable clients: %s", stale_clients) + logger.debug("Found stale clients: %s", stale_clients) logger.info("Searching for stale WireGuard clients.") stale_wireguard_clients = [ WireGuardClient(public_key=stale_client, domain=domain, remove=True) @@ -97,14 +97,18 @@ def link_handler(client: WireGuardClient) -> Dict: results = dict() # Updates WireGuard peers. results.update({"Wireguard": update_wireguard_peer(client)}) + logger.debug("Handling links for %s", client) try: # Updates routes to the WireGuard Peer. results.update({"Route": route_handler(client)}) + logger.info("Updated route for %s", client) except Exception as e: # TODO(ruairi): re-raise exception here. + logger.error("Failed to update route for %s (%s)", client, e) results.update({"Route": e}) # Updates WireGuard FDB. results.update({"Bridge FDB": bridge_fdb_handler(client)}) + logger.debug("Updated Bridge FDB for %s", client) return results @@ -185,12 +189,17 @@ def find_stale_wireguard_clients(wg_interface: str) -> List: three_hrs_in_secs = int( (datetime.now() - timedelta(hours=_PEER_TIMEOUT_HOURS)).timestamp() ) + logger.info( + "Starting search for stale wireguard peers for interface %s.", wg_interface + ) with pyroute2.WireGuard() as wg: all_clients = [] - infos = wg.info(wg_interface) - for info in infos: - clients = info.get_attr("WGDEVICE_A_PEERS") - if clients is not None: + peers_on_interface = wg.info(wg_interface) + logger.info("Got infos: %s.", peers_on_interface) + for peer in peers_on_interface: + clients = peer.get_attr("WGDEVICE_A_PEERS") + logger.info("Got clients: %s.", clients) + if clients: all_clients.extend(clients) ret = [ client.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8") diff --git a/wgkex/worker/netlink_test.py b/wgkex/worker/netlink_test.py index d930704..d528b56 100644 --- a/wgkex/worker/netlink_test.py +++ b/wgkex/worker/netlink_test.py @@ -22,26 +22,16 @@ public_key="public_key", domain="del", remove=True ) -_WG_PEER_STALE = mock.Mock() -_WG_PEER_STALE.WGPEER_A_PUBLIC_KEY = {"value": b"WGPEER_A_PUBLIC_KEY_STALE"} -_WG_PEER_STALE.WGPEER_A_LAST_HANDSHAKE_TIME = { - "tv_sec": int((datetime.now() - timedelta(hours=5)).timestamp()) -} - -_WG_PEER = mock.Mock() -_WG_PEER.WGPEER_A_PUBLIC_KEY = {"value": b"WGPEER_A_PUBLIC_KEY"} -_WG_PEER.WGPEER_A_LAST_HANDSHAKE_TIME = { - "tv_sec": int((datetime.now() - timedelta(seconds=3)).timestamp()) -} - - -def _get_wg_mock(peer): - info_mock = mock.Mock() - info_mock.WGDEVICE_A_PEERS.value = [peer] + +def _get_wg_mock(key_name, stale_time): + pm = mock.Mock() + pm.get_attr.side_effect = [{"tv_sec": stale_time}, key_name.encode()] + peer_mock = mock.Mock() + peer_mock.get_attr.side_effect = [[pm]] wg_instance = WireGuard() wg_info_mock = wg_instance.__enter__.return_value wg_info_mock.set.return_value = {"WireGuard": "set"} - wg_info_mock.info.return_value = [info_mock] + wg_info_mock.info.return_value = [peer_mock] return wg_info_mock @@ -53,12 +43,18 @@ def setUp(self) -> None: def test_find_stale_wireguard_clients_success_with_non_stale_peer(self): """Tests find_stale_wireguard_clients no operation on non-stale peers.""" - wg_info_mock = _get_wg_mock(_WG_PEER) + wg_info_mock = _get_wg_mock( + "WGPEER_A_PUBLIC_KEY", + int((datetime.now() - timedelta(seconds=3)).timestamp()), + ) self.assertListEqual([], netlink.find_stale_wireguard_clients("some_interface")) def test_find_stale_wireguard_clients_success_stale_peer(self): """Tests find_stale_wireguard_clients removal of stale peer""" - wg_info_mock = _get_wg_mock(_WG_PEER_STALE) + wg_info_mock = _get_wg_mock( + "WGPEER_A_PUBLIC_KEY_STALE", + int((datetime.now() - timedelta(hours=5)).timestamp()), + ) self.assertListEqual( ["WGPEER_A_PUBLIC_KEY_STALE"], netlink.find_stale_wireguard_clients("some_interface"), @@ -69,7 +65,7 @@ def test_route_handler_add_success(self): self.route_info_mock.route.return_value = {"key": "value"} self.assertDictEqual({"key": "value"}, netlink.route_handler(_WG_CLIENT_ADD)) self.route_info_mock.route.assert_called_with( - "add", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY + "replace", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY ) def test_route_handler_remove_success(self): @@ -82,7 +78,10 @@ def test_route_handler_remove_success(self): def test_update_wireguard_peer_success(self): """Test update_wireguard_peer for normal operation.""" - wg_info_mock = _get_wg_mock(_WG_PEER) + wg_info_mock = _get_wg_mock( + "WGPEER_A_PUBLIC_KEY", + int((datetime.now() - timedelta(seconds=3)).timestamp()), + ) self.assertDictEqual( {"WireGuard": "set"}, netlink.update_wireguard_peer(_WG_CLIENT_ADD) ) @@ -102,9 +101,10 @@ def test_bridge_fdb_handler_append_success(self): self.assertEqual({"key": "value"}, netlink.bridge_fdb_handler(_WG_CLIENT_ADD)) self.route_info_mock.fdb.assert_called_with( "append", - ifindex=mock.ANY, lladdr="00:00:00:00:00:00", dst="fe80::282:6eff:fe9d:ecd3", + ifindex=mock.ANY, + NDA_IFINDEX=mock.ANY, ) def test_bridge_fdb_handler_del_success(self): @@ -114,6 +114,7 @@ def test_bridge_fdb_handler_del_success(self): self.route_info_mock.fdb.assert_called_with( "del", ifindex=mock.ANY, + NDA_IFINDEX=mock.ANY, lladdr="00:00:00:00:00:00", dst="fe80::282:6eff:fe9d:ecd3", ) @@ -125,7 +126,10 @@ def test_link_handler_addition_success(self): "Route": {"IPRoute": "route"}, "Bridge FDB": {"IPRoute": "fdb"}, } - wg_info_mock = _get_wg_mock(_WG_PEER) + wg_info_mock = _get_wg_mock( + "WGPEER_A_PUBLIC_KEY", + int((datetime.now() - timedelta(seconds=3)).timestamp()), + ) wg_info_mock.set.return_value = {"WireGuard": "set"} self.route_info_mock.fdb.return_value = {"IPRoute": "fdb"} self.route_info_mock.route.return_value = {"IPRoute": "route"} @@ -133,11 +137,12 @@ def test_link_handler_addition_success(self): self.route_info_mock.fdb.assert_called_with( "append", ifindex=mock.ANY, + NDA_IFINDEX=mock.ANY, lladdr="00:00:00:00:00:00", dst="fe80::282:6eff:fe9d:ecd3", ) self.route_info_mock.route.assert_called_with( - "add", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY + "replace", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY ) wg_info_mock.set.assert_called_with( "wg-add", @@ -151,7 +156,10 @@ def test_link_handler_addition_success(self): def test_wg_flush_stale_peers_not_stale_success(self): """Tests processing of non-stale WireGuard Peer.""" - wg_info_mock = _get_wg_mock(_WG_PEER) + wg_info_mock = _get_wg_mock( + "WGPEER_A_PUBLIC_KEY", + int((datetime.now() - timedelta(seconds=3)).timestamp()), + ) self.route_info_mock.fdb.return_value = {"IPRoute": "fdb"} self.route_info_mock.route.return_value = {"IPRoute": "route"} self.assertListEqual([], netlink.wg_flush_stale_peers("domain")) @@ -169,7 +177,10 @@ def test_wg_flush_stale_peers_stale_success(self): ] self.route_info_mock.fdb.return_value = {"IPRoute": "fdb"} self.route_info_mock.route.return_value = {"IPRoute": "route"} - wg_info_mock = _get_wg_mock(_WG_PEER_STALE) + wg_info_mock = _get_wg_mock( + "WGPEER_A_PUBLIC_KEY_STALE", + int((datetime.now() - timedelta(hours=5)).timestamp()), + ) wg_info_mock.set.return_value = {"WireGuard": "set"} self.assertListEqual(expected, netlink.wg_flush_stale_peers("domain")) self.route_info_mock.route.assert_called_with(