Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests fixup #70

Merged
merged 1 commit into from
Sep 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions wgkex/broker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,3 @@ py_binary(
"//wgkex/config:config",
],
)

py_test(
name="app_test",
srcs=["app_test.py"],
deps=[
":app",
requirement("mock"),
],
)
90 changes: 0 additions & 90 deletions wgkex/broker/app_test.py

This file was deleted.

4 changes: 2 additions & 2 deletions wgkex/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def load_config() -> Dict[str, str]:
try:
config = yaml.safe_load(cfg_contents)
except yaml.YAMLError as e:
logger.error("Failed to load YAML file: %s", e)
print("Failed to load YAML file: %s", e)
sys.exit(1)
try:
_ = Config.from_dict(config)
return config
except (KeyError, TypeError) as e:
logger.error("Failed to lint file: %s", e)
print("Failed to lint file: %s", e)
sys.exit(2)


Expand Down
5 changes: 3 additions & 2 deletions wgkex/worker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ py_library(
requirement("paho-mqtt"),
requirement("pyroute2"),
"//wgkex/common:utils",
"//wgkex/common:logger"
"//wgkex/common:logger",
"//wgkex/config:config",
],
)

Expand All @@ -33,8 +34,8 @@ py_library(
requirement("paho-mqtt"),
requirement("pyroute2"),
"//wgkex/common:utils",
"//wgkex/config:config",
"//wgkex/common:logger",
"//wgkex/config:config",
":netlink",
],
)
Expand Down
19 changes: 14 additions & 5 deletions wgkex/worker/netlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]:
stale_clients = [
stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain)
]
logger.debug("Found stable clients: %s", stale_clients)
logger.debug("Found stale clients: %s", stale_clients)
logger.info("Searching for stale WireGuard clients.")
stale_wireguard_clients = [
WireGuardClient(public_key=stale_client, domain=domain, remove=True)
Expand All @@ -97,14 +97,18 @@ def link_handler(client: WireGuardClient) -> Dict:
results = dict()
# Updates WireGuard peers.
results.update({"Wireguard": update_wireguard_peer(client)})
logger.debug("Handling links for %s", client)
try:
# Updates routes to the WireGuard Peer.
results.update({"Route": route_handler(client)})
logger.info("Updated route for %s", client)
except Exception as e:
# TODO(ruairi): re-raise exception here.
logger.error("Failed to update route for %s (%s)", client, e)
results.update({"Route": e})
# Updates WireGuard FDB.
results.update({"Bridge FDB": bridge_fdb_handler(client)})
logger.debug("Updated Bridge FDB for %s", client)
return results


Expand Down Expand Up @@ -185,12 +189,17 @@ def find_stale_wireguard_clients(wg_interface: str) -> List:
three_hrs_in_secs = int(
(datetime.now() - timedelta(hours=_PEER_TIMEOUT_HOURS)).timestamp()
)
logger.info(
"Starting search for stale wireguard peers for interface %s.", wg_interface
)
with pyroute2.WireGuard() as wg:
all_clients = []
infos = wg.info(wg_interface)
for info in infos:
clients = info.get_attr("WGDEVICE_A_PEERS")
if clients is not None:
peers_on_interface = wg.info(wg_interface)
logger.info("Got infos: %s.", peers_on_interface)
for peer in peers_on_interface:
clients = peer.get_attr("WGDEVICE_A_PEERS")
logger.info("Got clients: %s.", clients)
if clients:
all_clients.extend(clients)
ret = [
client.get_attr("WGPEER_A_PUBLIC_KEY").decode("utf-8")
Expand Down
63 changes: 37 additions & 26 deletions wgkex/worker/netlink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,16 @@
public_key="public_key", domain="del", remove=True
)

_WG_PEER_STALE = mock.Mock()
_WG_PEER_STALE.WGPEER_A_PUBLIC_KEY = {"value": b"WGPEER_A_PUBLIC_KEY_STALE"}
_WG_PEER_STALE.WGPEER_A_LAST_HANDSHAKE_TIME = {
"tv_sec": int((datetime.now() - timedelta(hours=5)).timestamp())
}

_WG_PEER = mock.Mock()
_WG_PEER.WGPEER_A_PUBLIC_KEY = {"value": b"WGPEER_A_PUBLIC_KEY"}
_WG_PEER.WGPEER_A_LAST_HANDSHAKE_TIME = {
"tv_sec": int((datetime.now() - timedelta(seconds=3)).timestamp())
}


def _get_wg_mock(peer):
info_mock = mock.Mock()
info_mock.WGDEVICE_A_PEERS.value = [peer]

def _get_wg_mock(key_name, stale_time):
pm = mock.Mock()
pm.get_attr.side_effect = [{"tv_sec": stale_time}, key_name.encode()]
peer_mock = mock.Mock()
peer_mock.get_attr.side_effect = [[pm]]
wg_instance = WireGuard()
wg_info_mock = wg_instance.__enter__.return_value
wg_info_mock.set.return_value = {"WireGuard": "set"}
wg_info_mock.info.return_value = [info_mock]
wg_info_mock.info.return_value = [peer_mock]
return wg_info_mock


Expand All @@ -53,12 +43,18 @@ def setUp(self) -> None:

def test_find_stale_wireguard_clients_success_with_non_stale_peer(self):
"""Tests find_stale_wireguard_clients no operation on non-stale peers."""
wg_info_mock = _get_wg_mock(_WG_PEER)
wg_info_mock = _get_wg_mock(
"WGPEER_A_PUBLIC_KEY",
int((datetime.now() - timedelta(seconds=3)).timestamp()),
)
self.assertListEqual([], netlink.find_stale_wireguard_clients("some_interface"))

def test_find_stale_wireguard_clients_success_stale_peer(self):
"""Tests find_stale_wireguard_clients removal of stale peer"""
wg_info_mock = _get_wg_mock(_WG_PEER_STALE)
wg_info_mock = _get_wg_mock(
"WGPEER_A_PUBLIC_KEY_STALE",
int((datetime.now() - timedelta(hours=5)).timestamp()),
)
self.assertListEqual(
["WGPEER_A_PUBLIC_KEY_STALE"],
netlink.find_stale_wireguard_clients("some_interface"),
Expand All @@ -69,7 +65,7 @@ def test_route_handler_add_success(self):
self.route_info_mock.route.return_value = {"key": "value"}
self.assertDictEqual({"key": "value"}, netlink.route_handler(_WG_CLIENT_ADD))
self.route_info_mock.route.assert_called_with(
"add", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY
"replace", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY
)

def test_route_handler_remove_success(self):
Expand All @@ -82,7 +78,10 @@ def test_route_handler_remove_success(self):

def test_update_wireguard_peer_success(self):
"""Test update_wireguard_peer for normal operation."""
wg_info_mock = _get_wg_mock(_WG_PEER)
wg_info_mock = _get_wg_mock(
"WGPEER_A_PUBLIC_KEY",
int((datetime.now() - timedelta(seconds=3)).timestamp()),
)
self.assertDictEqual(
{"WireGuard": "set"}, netlink.update_wireguard_peer(_WG_CLIENT_ADD)
)
Expand All @@ -102,9 +101,10 @@ def test_bridge_fdb_handler_append_success(self):
self.assertEqual({"key": "value"}, netlink.bridge_fdb_handler(_WG_CLIENT_ADD))
self.route_info_mock.fdb.assert_called_with(
"append",
ifindex=mock.ANY,
lladdr="00:00:00:00:00:00",
dst="fe80::282:6eff:fe9d:ecd3",
ifindex=mock.ANY,
NDA_IFINDEX=mock.ANY,
)

def test_bridge_fdb_handler_del_success(self):
Expand All @@ -114,6 +114,7 @@ def test_bridge_fdb_handler_del_success(self):
self.route_info_mock.fdb.assert_called_with(
"del",
ifindex=mock.ANY,
NDA_IFINDEX=mock.ANY,
lladdr="00:00:00:00:00:00",
dst="fe80::282:6eff:fe9d:ecd3",
)
Expand All @@ -125,19 +126,23 @@ def test_link_handler_addition_success(self):
"Route": {"IPRoute": "route"},
"Bridge FDB": {"IPRoute": "fdb"},
}
wg_info_mock = _get_wg_mock(_WG_PEER)
wg_info_mock = _get_wg_mock(
"WGPEER_A_PUBLIC_KEY",
int((datetime.now() - timedelta(seconds=3)).timestamp()),
)
wg_info_mock.set.return_value = {"WireGuard": "set"}
self.route_info_mock.fdb.return_value = {"IPRoute": "fdb"}
self.route_info_mock.route.return_value = {"IPRoute": "route"}
self.assertEqual(expected, netlink.link_handler(_WG_CLIENT_ADD))
self.route_info_mock.fdb.assert_called_with(
"append",
ifindex=mock.ANY,
NDA_IFINDEX=mock.ANY,
lladdr="00:00:00:00:00:00",
dst="fe80::282:6eff:fe9d:ecd3",
)
self.route_info_mock.route.assert_called_with(
"add", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY
"replace", dst="fe80::282:6eff:fe9d:ecd3/128", oif=mock.ANY
)
wg_info_mock.set.assert_called_with(
"wg-add",
Expand All @@ -151,7 +156,10 @@ def test_link_handler_addition_success(self):

def test_wg_flush_stale_peers_not_stale_success(self):
"""Tests processing of non-stale WireGuard Peer."""
wg_info_mock = _get_wg_mock(_WG_PEER)
wg_info_mock = _get_wg_mock(
"WGPEER_A_PUBLIC_KEY",
int((datetime.now() - timedelta(seconds=3)).timestamp()),
)
self.route_info_mock.fdb.return_value = {"IPRoute": "fdb"}
self.route_info_mock.route.return_value = {"IPRoute": "route"}
self.assertListEqual([], netlink.wg_flush_stale_peers("domain"))
Expand All @@ -169,7 +177,10 @@ def test_wg_flush_stale_peers_stale_success(self):
]
self.route_info_mock.fdb.return_value = {"IPRoute": "fdb"}
self.route_info_mock.route.return_value = {"IPRoute": "route"}
wg_info_mock = _get_wg_mock(_WG_PEER_STALE)
wg_info_mock = _get_wg_mock(
"WGPEER_A_PUBLIC_KEY_STALE",
int((datetime.now() - timedelta(hours=5)).timestamp()),
)
wg_info_mock.set.return_value = {"WireGuard": "set"}
self.assertListEqual(expected, netlink.wg_flush_stale_peers("domain"))
self.route_info_mock.route.assert_called_with(
Expand Down