forked from vaseline555/Federated-Learning-in-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
70 lines (56 loc) · 2.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import time
import datetime
import pickle
import yaml
import threading
import logging
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from src.server import Server
from src.utils import launch_tensor_board
if __name__ == "__main__":
    # Entry point: read the seven-document config.yaml, set up per-run
    # logging + TensorBoard, run federated learning via Server, and dump
    # the collected results to a pickle inside the run directory.

    # Read configuration file: config.yaml holds seven YAML documents,
    # one per configuration group, in a fixed order.
    with open('./config.yaml') as c:
        configs = list(yaml.load_all(c, Loader=yaml.FullLoader))
    global_config = configs[0]["global_config"]
    data_config = configs[1]["data_config"]
    fed_config = configs[2]["fed_config"]
    optim_config = configs[3]["optim_config"]
    init_config = configs[4]["init_config"]
    model_config = configs[5]["model_config"]
    log_config = configs[6]["log_config"]

    # Modify log_path to contain the current time so each run gets its own
    # directory. BUGFIX: use '-' instead of ':' in the time component --
    # ':' is not a legal character in Windows file names.
    run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_config["log_path"] = os.path.join(log_config["log_path"], run_stamp)

    # Initiate TensorBoard for tracking losses and metrics.
    # (SummaryWriter also creates the run directory used by the log file below.)
    writer = SummaryWriter(log_dir=log_config["log_path"], filename_suffix="FL")
    # BUGFIX: the original assigned `Thread(...).start()` (which returns None)
    # to tb_thread, losing the thread handle; keep the Thread object and start
    # it separately. daemon=True so a blocking TensorBoard server cannot keep
    # the process alive after the experiment finishes.
    tb_thread = threading.Thread(
        target=launch_tensor_board,
        args=(log_config["log_path"], log_config["tb_port"], log_config["tb_host"]),
        daemon=True)
    tb_thread.start()
    time.sleep(3.0)  # give TensorBoard a moment to come up

    # Set the configuration of the global logger; the log file lives inside
    # the run directory created by SummaryWriter above.
    logging.basicConfig(
        filename=os.path.join(log_config["log_path"], log_config["log_name"]),
        level=logging.INFO,
        format="[%(levelname)s](%(asctime)s) %(message)s",
        datefmt="%Y/%m/%d/ %I:%M:%S %p")

    # Display and log experiment configuration.
    message = "\n[WELCOME] Unfolding configurations...!"
    print(message)
    logging.info(message)
    for config in configs:
        print(config)
        logging.info(config)
    print()

    # Initialize federated learning.
    central_server = Server(writer, model_config, global_config, data_config, init_config, fed_config, optim_config)
    central_server.setup()

    # Do federated learning.
    central_server.fit()

    # Save resulting losses and metrics.
    with open(os.path.join(log_config["log_path"], "result.pkl"), "wb") as f:
        pickle.dump(central_server.results, f)

    # Bye!
    message = "...done all learning process!\n...exit program!"
    print(message)
    logging.info(message)
    time.sleep(3)
    # BUGFIX: raise SystemExit instead of exit() -- the exit() builtin is a
    # site-module convenience and is not guaranteed to exist (e.g. under -S).
    raise SystemExit(0)