-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
57 lines (49 loc) · 2.56 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import torch
def compute_local_test_accuracy(model, dataloader, data_distribution):
    """Evaluate ``model`` and return personalized and generalized accuracy.

    The generalized accuracy is the plain fraction of correctly classified
    samples. The personalized accuracy weights the per-label correct/total
    counts by ``data_distribution`` before taking the ratio.

    The model is switched to eval mode, moved to CUDA when it is available
    (otherwise evaluated on the CPU), and moved to the CPU before returning.

    Args:
        model: Callable module mapping a batch ``x`` to per-class scores.
            Predictions are taken as the arg-max over dimension 1, so the
            output is assumed to be (batch, num_classes) -- TODO confirm.
        dataloader: Iterable yielding ``(x, target)`` tensor pairs; ``target``
            is assumed to be a 1-D tensor of integer class labels.
        data_distribution: Sequence of per-label weights; its length sets the
            number of label bins.

    Returns:
        Tuple ``(personalized_accuracy, generalized_accuracy)``.

    Raises:
        IndexError: If a label falls outside ``len(data_distribution)``.
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    # Fall back to the CPU instead of crashing on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()

    num_labels = len(data_distribution)
    total_label_num = np.zeros(num_labels)
    correct_label_num = np.zeros(num_labels)

    model.to(device)
    generalized_total, generalized_correct = 0, 0
    with torch.no_grad():
        for x, target in dataloader:
            x = x.to(device)
            target = target.to(device=device, dtype=torch.int64)
            out = model(x)
            _, pred_label = torch.max(out, 1)
            correct_filter = pred_label == target

            # Copy labels and the hit mask to the host once per batch rather
            # than indexing device tensors element by element.
            labels = target.cpu().numpy()
            correct_mask = correct_filter.cpu().numpy()

            generalized_total += x.size(0)
            generalized_correct += int(correct_mask.sum())

            # np.add.at accumulates repeated indices, unlike ``arr[idx] += 1``.
            np.add.at(total_label_num, labels, 1)
            np.add.at(correct_label_num, labels[correct_mask], 1)

    personalized_correct = (correct_label_num * data_distribution).sum()
    personalized_total = (total_label_num * data_distribution).sum()
    model.to('cpu')
    return personalized_correct / personalized_total, generalized_correct / generalized_total
def compute_acc(net, test_data_loader):
    """Return the fraction of samples in ``test_data_loader`` that ``net`` classifies correctly.

    The network is moved to CUDA when it is available (otherwise evaluated on
    the CPU), put into eval mode, and moved to the CPU before returning.

    Args:
        net: Callable module mapping a batch ``x`` to per-class scores.
            Predictions are taken as the arg-max over dimension 1.
        test_data_loader: Iterable yielding ``(x, target)`` tensor pairs.

    Returns:
        Accuracy as a float: correct predictions divided by total samples.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    # Fall back to the CPU instead of crashing on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, target in test_data_loader:
            x = x.to(device)
            target = target.to(device=device, dtype=torch.int64)
            out = net(x)
            _, pred_label = torch.max(out, 1)
            total += x.size(0)
            correct += (pred_label == target).sum().item()
    net.to('cpu')
    return correct / float(total)
def evaluate_global_model(args, nets_this_round, global_model, val_local_dls, test_dl, data_distributions, best_val_acc_list, best_test_acc_list, benign_client_list):
    """Score the global model for each benign client of this round.

    For every key of ``nets_this_round`` that is also in
    ``benign_client_list``, the global model is evaluated on that client's
    validation loader and on the shared test loader, weighted by the client's
    data distribution. When the validation accuracy beats the stored best for
    that client, both best-so-far lists are updated in place. The
    personalized and generalized test accuracies are printed per client.

    Args:
        args: Not used by this function.
        nets_this_round: Mapping whose keys are the client ids to evaluate.
        global_model: Model passed to the accuracy helpers.
        val_local_dls: Per-client validation loaders, indexed by client id.
        test_dl: Test loader shared by all clients.
        data_distributions: Per-client label weights, indexed by client id.
        best_val_acc_list: Best validation accuracy per client; mutated.
        best_test_acc_list: Personalized test accuracy recorded at the best
            validation accuracy per client; mutated.
        benign_client_list: Ids of the clients included in the evaluation
            and in the returned mean.

    Returns:
        Mean of ``best_test_acc_list`` over the benign clients.
    """
    report_template = '>> Client {} | Personalized Test Acc: {:.5f} | Generalized Test Acc: {:.5f}'

    for client_id, _client_net in nets_this_round.items():
        # Only benign clients are evaluated.
        if client_id not in benign_client_list:
            continue

        client_val_loader = val_local_dls[client_id]
        client_distribution = data_distributions[client_id]

        client_val_acc = compute_acc(global_model, client_val_loader)
        accuracies = compute_local_test_accuracy(global_model, test_dl, client_distribution)
        personalized_acc, generalized_acc = accuracies

        # Model selection is driven by validation accuracy only.
        improved = client_val_acc > best_val_acc_list[client_id]
        if improved:
            best_val_acc_list[client_id] = client_val_acc
            best_test_acc_list[client_id] = personalized_acc

        print(report_template.format(client_id, personalized_acc, generalized_acc))

    best_test_accs = np.array(best_test_acc_list)
    benign_indices = np.array(benign_client_list)
    return best_test_accs[benign_indices].mean()