Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement: enable lazy setting of nested dicts #775

Merged
merged 11 commits into from
Jul 18, 2024
36 changes: 29 additions & 7 deletions garak/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# logging should be set up before config is loaded

from collections import defaultdict
from dataclasses import dataclass
import importlib
import logging
Expand All @@ -15,6 +16,8 @@
from typing import List
import yaml

DICT_CONFIG_AFTER_LOAD = False

version = -1 # eh why this is here? hm. who references it

system_params = (
Expand Down Expand Up @@ -60,11 +63,28 @@ class TransientConfig(GarakSubConfig):
run = GarakSubConfig()
plugins = GarakSubConfig()
reporting = GarakSubConfig()
plugins.probes = {}
plugins.generators = {}
plugins.detectors = {}
plugins.buffs = {}
plugins.harnesses = {}


def _lock_config_as_dict():
global plugins
for plugin_type in ("probes", "generators", "buffs", "detectors", "harnesses"):
setattr(plugins, plugin_type, _crystallise(getattr(plugins, plugin_type)))


def _crystallise(d):
for k in d.keys():
if isinstance(d[k], defaultdict):
d[k] = _crystallise(d[k])
return dict(d)


nested_dict = lambda: defaultdict(nested_dict)

plugins.probes = nested_dict()
plugins.generators = nested_dict()
plugins.detectors = nested_dict()
plugins.buffs = nested_dict()
plugins.harnesses = nested_dict()
reporting.taxonomy = None # set here to enable report_digest to be called directly

buffmanager = BuffManager()
Expand All @@ -87,7 +107,7 @@ def _set_settings(config_obj, settings_obj: dict):
def _combine_into(d: dict, combined: dict) -> None:
for k, v in d.items():
if isinstance(v, dict):
_combine_into(v, combined.setdefault(k, {}))
_combine_into(v, combined.setdefault(k, nested_dict()))
else:
combined[k] = v
return combined
Expand All @@ -96,7 +116,7 @@ def _combine_into(d: dict, combined: dict) -> None:
def _load_yaml_config(settings_filenames) -> dict:
global config_files
config_files += settings_filenames
config = {}
config = nested_dict()
for settings_filename in settings_filenames:
with open(settings_filename, encoding="utf-8") as settings_file:
settings = yaml.safe_load(settings_file)
Expand Down Expand Up @@ -155,6 +175,8 @@ def load_config(

logging.debug("Loading configs from: %s", ",".join(settings_files))
_store_config(settings_files=settings_files)
if DICT_CONFIG_AFTER_LOAD:
_lock_config_as_dict()
loaded = True


Expand Down
7 changes: 7 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,10 @@ def test_report_prefix_with_hitlog_no_explode():
assert os.path.isfile("kjsfhgkjahpsfdg.report.jsonl")
assert os.path.isfile("kjsfhgkjahpsfdg.report.html")
assert os.path.isfile("kjsfhgkjahpsfdg.hitlog.jsonl")


def test_nested():
importlib.reload(_config)

_config.plugins.generators["a"]["b"]["c"]["d"] = "e"
assert _config.plugins.generators["a"]["b"]["c"]["d"] == "e"
Loading