Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR to centralising logging configuration #66

Merged
merged 5 commits into from
Sep 11, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions wgkex/broker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@
import paho.mqtt.client as mqtt_client

from wgkex.config import config
from wgkex.common import logger

logging.basicConfig(
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=config.load_config().get("log_level"),
)

WG_PUBKEY_PATTERN = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$")

Expand Down Expand Up @@ -94,7 +90,7 @@ def wg_key_exchange() -> Tuple[str, int]:
domain = data.domain
# in case we want to decide here later we want to publish it only to dedicated gateways
gateway = "all"
logging.info(f"wg_key_exchange: Domain: {domain}, Key:{key}")
logger.info(f"wg_key_exchange: Domain: {domain}, Key:{key}")

mqtt.publish(f"wireguard/{domain}/{gateway}", key)
return jsonify({"Message": "OK"}), 200
Expand All @@ -106,7 +102,7 @@ def handle_mqtt_connect(
) -> None:
"""Prints status of connect message."""
# TODO(ruairi): Clarify current usage of this function.
logging.debug(
logger.debug(
"MQTT connected to {}:{}".format(
app.config["MQTT_BROKER_URL"], app.config["MQTT_BROKER_PORT"]
)
Expand All @@ -120,7 +116,7 @@ def handle_mqtt_message(
) -> None:
"""Prints message contents."""
# TODO(ruairi): Clarify current usage of this function.
logging.debug(
logger.debug(
f"MQTT message received on {message.topic}: {message.payload.decode()}"
)

Expand Down
6 changes: 6 additions & 0 deletions wgkex/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ py_test(
requirement("mock"),
],
)

py_library(
name = "logger",
srcs = ["logger.py"],
visibility = ["//visibility:public"]
)
13 changes: 13 additions & 0 deletions wgkex/common/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from logging import basicConfig
from logging import DEBUG
from logging import info as info
from logging import warning as warning
from logging import error as error
from logging import critical as critical
from logging import debug as debug

basicConfig(
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=DEBUG,
awlx marked this conversation as resolved.
Show resolved Hide resolved
awlx marked this conversation as resolved.
Show resolved Hide resolved
)
5 changes: 4 additions & 1 deletion wgkex/config/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ py_library(
name="config",
srcs=["config.py"],
visibility=["//visibility:public"],
deps=[requirement("PyYAML"), "//wgkex/common:utils"],
deps=[requirement("PyYAML"),
"//wgkex/common:utils",
"//wgkex/common:logger",
],
)

py_test(
Expand Down
12 changes: 3 additions & 9 deletions wgkex/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from functools import lru_cache
from typing import Dict, Union, Any, List, Optional
import dataclasses
import logging

from wgkex.common import logger

class Error(Exception):
"""Base Exception handling class."""
Expand All @@ -18,11 +17,6 @@ class ConfigFileNotFoundError(Error):

WG_CONFIG_OS_ENV = "WGKEX_CONFIG_FILE"
WG_CONFIG_DEFAULT_LOCATION = "/etc/wgkex.yaml"
logging.basicConfig(
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.DEBUG,
)


@dataclasses.dataclass
Expand Down Expand Up @@ -113,13 +107,13 @@ def load_config() -> Dict[str, str]:
try:
config = yaml.safe_load(cfg_contents)
except yaml.YAMLError as e:
logging.error("Failed to load YAML file: %s", e)
logger.error("Failed to load YAML file: %s", e)
sys.exit(1)
try:
_ = Config.from_dict(config)
return config
except (KeyError, TypeError) as e:
logging.error("Failed to lint file: %s", e)
logger.error("Failed to lint file: %s", e)
sys.exit(2)


Expand Down
5 changes: 4 additions & 1 deletion wgkex/worker/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ py_library(
requirement("NetLink"),
requirement("paho-mqtt"),
requirement("pyroute2"),
"//wgkex/common:utils"
"//wgkex/common:utils",
"//wgkex/common:logger"
],
)

Expand All @@ -33,6 +34,7 @@ py_library(
requirement("pyroute2"),
"//wgkex/common:utils",
"//wgkex/config:config",
"//wgkex/common:logger",
":netlink",
],
)
Expand All @@ -52,6 +54,7 @@ py_binary(
deps = [
":mqtt",
"//wgkex/config:config",
"//wgkex/common:logger",
],
)

Expand Down
19 changes: 6 additions & 13 deletions wgkex/worker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,9 @@
from wgkex.worker.netlink import wg_flush_stale_peers
import threading
import time
import logging
import datetime
from wgkex.common import logger
from typing import List, Text

logging.basicConfig(
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=config.load_config().get("log_level"),
)

_CLEANUP_TIME = 3600


Expand All @@ -30,8 +23,8 @@ def flush_workers(domain: Text) -> None:
"""Calls peer flush every _CLEANUP_TIME interval."""
while True:
time.sleep(_CLEANUP_TIME)
logging.info(f"Running cleanup task for {domain}")
logging.info("Cleaned up domains: %s", wg_flush_stale_peers(domain))
logger.info(f"Running cleanup task for {domain}")
logger.info("Cleaned up domains: %s", wg_flush_stale_peers(domain))


def clean_up_worker(domains: List[Text]) -> None:
Expand All @@ -40,14 +33,14 @@ def clean_up_worker(domains: List[Text]) -> None:
Arguments:
domains: list of domains.
"""
logging.debug("Cleaning up the following domains: %s", domains)
logger.debug("Cleaning up the following domains: %s", domains)
prefix = config.load_config().get("domain_prefix")
for domain in domains:
logging.info("Scheduling cleanup task for %s, ", domain)
logger.info("Scheduling cleanup task for %s, ", domain)
try:
cleaned_domain = domain.split(prefix)[1]
except IndexError:
logging.error(
logger.error(
"Cannot strip domain with prefix %s from passed value %s. Skipping cleanup operation",
prefix,
domain,
Expand Down
2 changes: 1 addition & 1 deletion wgkex/worker/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_main_success(self, connect_mock, config_mock):
)
with mock.patch("app.flush_workers", return_value=None):
app.main()
connect_mock.assert_called_with(["TEST_PREFIX_domain.one"])
connect_mock.assert_called_with()

@mock.patch.object(app.config, "load_config")
@mock.patch.object(app.mqtt, "connect", autospec=True)
Expand Down
27 changes: 12 additions & 15 deletions wgkex/worker/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
from wgkex.worker.netlink import link_handler
from wgkex.worker.netlink import WireGuardClient
from typing import Optional, Dict, List, Any, Union
import logging

logging.basicConfig(
format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=load_config().get("log_level"),
)
from wgkex.common import logger


def fetch_from_config(var: str) -> Optional[Union[Dict[str, str], str]]:
Expand Down Expand Up @@ -54,7 +48,7 @@ def connect() -> None:
# Register handlers
client.on_connect = on_connect
client.on_message = on_message
logging.info("connecting to broker %s", broker_address)
logger.info("connecting to broker %s", broker_address)

client.connect(broker_address, port=broker_port, keepalive=broker_keepalive)
client.loop_forever()
Expand All @@ -70,14 +64,14 @@ def on_connect(client: mqtt.Client, userdata: Any, flags, rc) -> None:
flags: The MQTT flags.
rc: The MQTT rc.
"""
logging.debug("Connected with result code " + str(rc))
logger.debug("Connected with result code " + str(rc))
domains = load_config().get("domains")

# Subscribing in on_connect() means that if we lose the connection and
# reconnect then subscriptions will be renewed.
for domain in domains:
topic = f"wireguard/{domain}/+"
logging.info(f"Subscribing to topic {topic}")
logger.info(f"Subscribing to topic {topic}")
client.subscribe(topic)


Expand All @@ -90,15 +84,18 @@ def on_message(client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage) ->
message: The MQTT message.
"""
# TODO(ruairi): Check bounds and raise exception here.
logging.debug("Got message %s from MTQQ", message)
logger.debug("Got message %s from MTQQ", message)
domain_prefix = load_config().get("domain_prefix")
domain = re.search(r"/.*" + domain_prefix + "(\w+)/", message.topic).group(1)
logging.debug("Found domain %s", domain)
domain = re.search(r"/.*" + domain_prefix + "(\w+)/", message.topic)
if not domain:
raise ValueError('Could not find a match for %s on %s', domain_prefix, message.topic)
domain = domain.group(1)
logger.debug("Found domain %s", domain)
client = WireGuardClient(
public_key=str(message.payload.decode("utf-8")),
domain=domain,
remove=False,
)
logging.info(f"Received create message for key {client.public_key} on domain {domain} with lladdr {client.lladdr}")
logger.info(f"Received create message for key {client.public_key} on domain {domain} with lladdr {client.lladdr}")
# TODO(ruairi): Verify return type here.
logging.debug(link_handler(client))
logger.debug(link_handler(client))
16 changes: 10 additions & 6 deletions wgkex/worker/mqtt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import mock
import mqtt


class MQTTTest(unittest.TestCase):

@mock.patch.object(mqtt, "load_config")
def test_fetch_from_config_success(self, config_mock):
"""Ensure we can fetch a value from config."""
Expand All @@ -25,7 +25,7 @@ def test_connect_success(self, config_mock, hostname_mock, mqtt_mock):
"""Tests successful connection to MQTT server."""
hostname_mock.return_value = "hostname"
config_mock.return_value = dict(mqtt={"broker_url": "some_url"})
mqtt.connect(["domain1", "domain2"])
mqtt.connect()
mqtt_mock.assert_has_calls(
[mock.call().connect("some_url", port=None, keepalive=None)],
any_order=True,
Expand All @@ -38,11 +38,13 @@ def test_connect_fails_mqtt_error(self, config_mock, mqtt_mock):
mqtt_mock.side_effect = ValueError("barf")
config_mock.return_value = dict(mqtt={"broker_url": "some_url"})
with self.assertRaises(ValueError):
mqtt.connect(["domain1", "domain2"])
mqtt.connect()

@mock.patch.object(mqtt, "link_handler")
def test_on_message_success(self, link_mock):
@mock.patch.object(mqtt, "load_config")
def test_on_message_success(self, config_mock, link_mock):
"""Tests on_message for success."""
config_mock.return_value = {'domain_prefix': '_ffmuc_'}
link_mock.return_value = dict(WireGuard="result")
mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage")
mqtt_msg.topic = "/_ffmuc_domain1/"
Expand All @@ -60,12 +62,14 @@ def test_on_message_success(self, link_mock):
)

@mock.patch.object(mqtt, "link_handler")
def test_on_message_fails_no_domain(self, link_mock):
@mock.patch.object(mqtt, "load_config")
def test_on_message_fails_no_domain(self, config_mock, link_mock):
"""Tests on_message for failure to parse domain."""
config_mock.return_value = {'domain_prefix': 'ffmuc_', 'log_level': 'DEBUG', 'domains': ['a', 'b'], 'mqtt': {'broker_port': 1883, 'broker_url': 'mqtt://broker', 'keepalive': 5, 'password': 'pass', 'tls': True, 'username': 'user'}}
link_mock.return_value = dict(WireGuard="result")
mqtt_msg = mock.patch.object(mqtt.mqtt, "MQTTMessage")
mqtt_msg.topic = "bad_domain_match"
with self.assertRaises(AttributeError):
with self.assertRaises(ValueError):
mqtt.on_message(None, None, mqtt_msg)


Expand Down
7 changes: 7 additions & 0 deletions wgkex/worker/netlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pyroute2

from wgkex.common.utils import mac2eui64
from wgkex.common import logger

_PERSISTENT_KEEPALIVE_SECONDS = 15
_PEER_TIMEOUT_HOURS = 3
Expand Down Expand Up @@ -65,16 +66,22 @@ def wg_flush_stale_peers(domain: str) -> List[Dict]:
Returns:
The peers which we can remove.
"""
logger.info('Searching for stale clients for %s', domain)
stale_clients = [
stale_client for stale_client in find_stale_wireguard_clients("wg-" + domain)
]
logger.debug('Found stable clients: %s', stale_clients)
logger.info('Searching for stale WireGuard clients.')
stale_wireguard_clients = [
WireGuardClient(public_key=stale_client, domain=domain, remove=True)
for stale_client in stale_clients
]
logger.debug('Found stable WireGuard clients: %s', stale_wireguard_clients)
logger.info('Processing clients.')
link_handled = [
link_handler(stale_client) for stale_client in stale_wireguard_clients
]
logger.debug('Handled the following clients: %s', link_handled)
return link_handled


Expand Down