Skip to content

Commit

Permalink
Add more tests for worker/netlink.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DasSkelett authored and sealrealize committed Dec 17, 2023
1 parent fdb287c commit a2bf54d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
2 changes: 1 addition & 1 deletion wgkex/worker/netlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_connected_peers_count(wg_interface: str) -> int:
return count


def get_device_data(wg_interface: str) -> Tuple[Any, Any, Any]:
def get_device_data(wg_interface: str) -> Tuple[int, str, str]:
"""Returns the listening port, public key and local IP address.
Arguments:
Expand Down
67 changes: 62 additions & 5 deletions wgkex/worker/netlink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,31 @@
)


def _get_wg_mock(key_name, stale_time):
pm = mock.Mock()
pm.get_attr.side_effect = [{"tv_sec": stale_time}, key_name.encode()]
def _get_peer_mock(public_key, last_handshake_time):
def peer_get_attr(attr: str):
if attr == "WGPEER_A_LAST_HANDSHAKE_TIME":
return {"tv_sec": last_handshake_time}
if attr == "WGPEER_A_PUBLIC_KEY":
return public_key.encode()

peer_mock = mock.Mock()
peer_mock.get_attr.side_effect = [[pm]]
peer_mock.get_attr.side_effect = peer_get_attr
return peer_mock


def _get_wg_mock(public_key, last_handshake_time):
peer_mock = _get_peer_mock(public_key, last_handshake_time)

def msg_get_attr(attr: str):
if attr == "WGDEVICE_A_PEERS":
return [peer_mock]

msg_mock = mock.Mock()
msg_mock.get_attr.side_effect = msg_get_attr
wg_instance = WireGuard()
wg_info_mock = wg_instance.__enter__.return_value
wg_info_mock.set.return_value = {"WireGuard": "set"}
wg_info_mock.info.return_value = [peer_mock]
wg_info_mock.info.return_value = [msg_mock]
return wg_info_mock


Expand Down Expand Up @@ -188,6 +204,47 @@ def test_wg_flush_stale_peers_stale_success(self):
"del", dst="fe80::281:16ff:fe49:395e/128", oif=mock.ANY
)

def test_get_connected_peers_count_success(self):
"""Tests getting the correct number of connected peers for an interface."""
peers = []
for i in range(10):
peer_mock = _get_peer_mock(
"TEST_KEY",
int((datetime.now() - timedelta(minutes=i, seconds=5)).timestamp()),
)
peers.append(peer_mock)

def msg_get_attr(attr: str):
if attr == "WGDEVICE_A_PEERS":
return peers

msg_mock = mock.Mock()
msg_mock.get_attr.side_effect = msg_get_attr

wg_instance = WireGuard()
wg_info_mock = wg_instance.__enter__.return_value
wg_info_mock.info.return_value = [msg_mock]

ret = netlink.get_connected_peers_count("wg-welt")
self.assertEqual(ret, 3)

def test_get_device_data_success(self):
def msg_get_attr(attr: str):
if attr == "WGDEVICE_A_LISTEN_PORT":
return 51820
if attr == "WGDEVICE_A_PUBLIC_KEY":
return "TEST_PUBLIC_KEY".encode("ascii")

msg_mock = mock.Mock()
msg_mock.get_attr.side_effect = msg_get_attr

wg_instance = WireGuard()
wg_info_mock = wg_instance.__enter__.return_value
wg_info_mock.info.return_value = [msg_mock]

ret = netlink.get_device_data("wg-welt")
self.assertTupleEqual(ret, (51820, "TEST_PUBLIC_KEY", mock.ANY))


if __name__ == "__main__":
unittest.main()

0 comments on commit a2bf54d

Please sign in to comment.