Fed_utils.py
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from apex.parallel import DistributedDataParallel

from myNetwork import *
from train import Trainer
from train_rcil import Trainer_rcil


def setup_seed(seed):
    """Seed all RNGs (PyTorch, CUDA, NumPy, random) for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def local_train(args, clients, index, model_g, current_step, ep_g):
    """Run one round of local training on client `index` and return its local model."""
    clients[index].beforeTrain(args, current_step)
    if not args.base_weights:
        if args.use_entropy_detection:
            # Refresh the client's entropy signal from the current global model.
            clients[index].update_entropy_signal(model_g)
        local_model = clients[index].train(args, model_g, ep_g)
    else:
        # When starting from base weights, no local update is performed.
        local_model = None
    return local_model


def FedAvg(models):
    """Federated averaging: element-wise mean of a list of client state_dicts."""
    w_avg = copy.deepcopy(models[0])
    for k in w_avg.keys():
        for i in range(1, len(models)):
            w_avg[k] += models[i][k]
        w_avg[k] = torch.div(w_avg[k], len(models))
    return w_avg


def model_global_eval(args, model_g, test_loader, current_step, val_metrics, device, rank):
    """Evaluate the aggregated global model on the test loader and return the validation score."""
    tmp_model_g = copy.deepcopy(model_g)
    tmp_model_g = DistributedDataParallel(tmp_model_g.cuda(device))
    if args.incremental_method != 'RCIL':
        trainer = Trainer(tmp_model_g, None, device=device, rank=rank, opts=args, step=current_step)
    else:
        trainer = Trainer_rcil(tmp_model_g, None, device=device, rank=rank, opts=args, step=current_step)
    tmp_model_g.eval()
    _, val_score, _ = trainer.validate(
        loader=test_loader, metrics=val_metrics, end_task=True
    )
    if rank == 0:
        print(val_metrics.to_str(val_score))
    # Release GPU memory held by the temporary evaluation copy.
    tmp_model_g = tmp_model_g.to('cpu')
    torch.cuda.empty_cache()
    del tmp_model_g
    del trainer
    return val_score
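

# ---------------------------------------------------------------------------
# Minimal sanity check for FedAvg (an illustrative sketch, not part of the
# original utilities): it averages the state_dicts of two small linear models
# and verifies the result element-wise. The layer shape (4 -> 2), the number
# of "clients", and the seed value are arbitrary choices for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    setup_seed(0)

    # Two toy client models with identical architecture but different weights.
    client_a = nn.Linear(4, 2)
    client_b = nn.Linear(4, 2)

    averaged = FedAvg([client_a.state_dict(), client_b.state_dict()])

    # The aggregated weights should equal the element-wise mean of the clients'.
    expected = (client_a.state_dict()["weight"] + client_b.state_dict()["weight"]) / 2
    assert torch.allclose(averaged["weight"], expected)
    print("FedAvg sanity check passed.")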