-
Notifications
You must be signed in to change notification settings - Fork 3
/
main_fed.py
95 lines (77 loc) · 3.3 KB
/
main_fed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import copy
import numpy as np
import torch
from models.Fed import FedAvg
from models.Nets import MLP, Mnistcnn
from models.Sia import SIA
from models.Update import LocalUpdate
from models.test import test_fun
from utils.dataset import get_dataset, exp_details
from utils.options import args_parser
if __name__ == '__main__':
    # Federated training driver: trains a global model with FedAvg and, after
    # every communication round, runs the source inference attack (SIA) on the
    # local models uploaded that round, tracking the best attack accuracy.

    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split data for users
    dataset_train, dataset_test, dict_party_user, dict_sample_user = get_dataset(args)

    # build model
    if args.model == 'cnn' and args.dataset == 'MNIST':
        net_glob = Mnistcnn(args=args).to(args.device)
    elif args.model == 'mlp':
        # NOTE(review): assumes get_dataset returns wrappers that expose the
        # underlying dataset via `.dataset` (e.g. torch Subset) — confirm.
        dataset_train = dataset_train.dataset
        dataset_test = dataset_test.dataset
        # The MLP input width is the product of the first sample's feature dims.
        len_in = 1
        for dim in dataset_train[0][0].shape:
            len_in *= dim
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        raise SystemExit('Error: unrecognized model')

    # Scratch network handed to the attacker. It must be a separate copy:
    # aliasing net_glob would let the attack overwrite the global model's
    # parameters (and its train/eval mode) in the middle of a round.
    empty_net = copy.deepcopy(net_glob)
    print('Model architecture:')
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    if args.all_clients:
        print("Aggregation over all clients")
        # NOTE(review): every slot is the same dict, and in PyTorch
        # state_dict() tensors share storage with net_glob's parameters, so a
        # client that has not been sampled yet contributes net_glob's current
        # weights to FedAvg. Confirm this is the intended fallback.
        w_locals = [w_glob for _ in range(args.num_users)]
    best_att_acc = 0
    m = max(int(args.frac * args.num_users), 1)  # clients sampled per round
    for round_idx in range(args.epochs):
        loss_locals = []
        if not args.all_clients:
            w_locals = []
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_party_user[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        # implement the source inference attack
        # NOTE(review): when all_clients is off, w_locals is ordered by the
        # sampled idxs_users, not by client id; confirm SIA does not assume
        # w_locals[i] belongs to client i of dict_sample_user.
        SIA_attack = SIA(args=args, w_locals=w_locals, dataset=dataset_train, dict_sia_users=dict_sample_user)
        attack_acc = SIA_attack.attack(net=empty_net.to(args.device))
        best_att_acc = max(best_att_acc, attack_acc)

        # update global weights and copy them into net_glob
        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)
        # NOTE(review): the result is never used (training accuracy is
        # recomputed after the loop); kept only in case test_fun has side
        # effects such as logging — drop it if it does not.
        test_fun(net_glob, dataset_train, args)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average training loss {:.3f}'.format(round_idx, loss_avg))

    # testing
    net_glob.eval()
    acc_train, loss_train_ = test_fun(net_glob, dataset_train, args)
    acc_test, loss_test = test_fun(net_glob, dataset_test, args)

    # experiment setting
    exp_details(args)
    print('Experimental result summary:')
    print("Training accuracy of the joint model: {:.2f}".format(acc_train))
    print("Testing accuracy of the joint model: {:.2f}".format(acc_test))
    print('Random guess baseline of source inference : {:.2f}'.format(1.0/args.num_users*100))
    print('Highest prediction loss based source inference accuracy: {:.2f}'.format(best_att_acc))