-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgenerate_config_files.py
50 lines (41 loc) · 1.14 KB
/
generate_config_files.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import toml
import typer
from nilmth.utils.config import load_config
# Values written to config["model"]["threshold"]["method"], one per
# generated file.
LIST_METHODS = [
    "mp",
    "vs",
    "at",
]

# Values written to config["model"]["name"], one per generated file.
LIST_MODELS = [
    "ConvModel",
    "GRUModel",
]

# Classification weights to sweep; the matching regression weight is derived
# as ``1 - weight`` when the files are generated. The two endpoints are kept
# as ints on purpose: they are formatted into the output file names, so they
# must read ``classw_0`` / ``classw_1`` rather than ``classw_0.0``.
LIST_CLASS_W = [
    0, 0.01, 0.05,
    0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
    0.95, 0.99, 1,
]
def main(path_config: str = "nilmth/config.toml", path_output: str = "configs"):
    """Write one TOML config per model / threshold method / class weight.

    The base configuration is loaded from ``path_config``. For every
    combination of ``LIST_MODELS``, ``LIST_METHODS`` and ``LIST_CLASS_W`` the
    model name, the threshold method and the classification / regression
    weights are overridden, and the result is dumped to
    ``<path_output>/<model>_<method>_classw_<class_w>.toml``.

    Args:
        path_config: Path of the base configuration, read by ``load_config``.
        path_output: Directory receiving the generated files. It is created,
            together with any missing parent directories, if absent.
    """
    config = load_config(path_config)

    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
    # check-then-create race and also handles nested output paths.
    os.makedirs(path_output, exist_ok=True)

    # The same config mapping is mutated in place and re-dumped each time;
    # every swept key is overwritten before each dump, so no state leaks
    # from one generated file into the next.
    for model_name in LIST_MODELS:
        config["model"].update({"name": model_name})
        for method in LIST_METHODS:
            config["model"]["threshold"].update({"method": method})
            for class_w in LIST_CLASS_W:
                config["model"].update(
                    {
                        "classification_w": class_w,
                        # Rounded to hide float noise such as
                        # 1 - 0.95 == 0.050000000000000044.
                        "regression_w": round(1 - class_w, 2),
                    }
                )
                path_dump = os.path.join(
                    path_output, f"{model_name}_{method}_classw_{class_w}.toml"
                )
                # TOML documents must be UTF-8; do not rely on the locale's
                # default encoding.
                with open(path_dump, "w", encoding="utf-8") as toml_file:
                    toml.dump(config, toml_file)
# Script entry point: only when executed directly (not imported), hand
# ``main`` to ``typer.run`` so it is invoked from the command line.
if __name__ == "__main__":
    typer.run(main)