-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
126 lines (105 loc) · 5.64 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn
import config
import main
def Mytest(helper, epoch,
           model, is_poison=False, visualize=True, agent_name_key=""):
    """Evaluate ``model`` on the clean test set, then log and optionally plot it.

    Iterates ``helper.test_data``, converts each raw batch with
    ``helper.get_batch(data_iterator, batch, evaluation=True)``, and
    accumulates the summed cross-entropy loss and the number of correct
    top-1 predictions.

    Args:
        helper: Experiment helper; must provide ``test_data``, ``get_batch``
            and ``params['environment_name']``.
        epoch: Current round/epoch; only used in the log line and the plot.
        model: Classifier to evaluate; must expose ``name`` and ``test_vis``.
        is_poison: Only echoed in the log line.
        visualize: If true, report accuracy through ``model.test_vis`` using
            ``main.vis``.
        agent_name_key: Forwarded (converted to ``str``) to ``test_vis``.

    Returns:
        Tuple ``(avg_loss, acc_percent, correct, dataset_size)``. Average loss
        and accuracy are 0 when the test set yields no samples.

    The model is switched back to training mode afterwards, even if
    evaluation raises.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    dataset_size = 0
    data_iterator = helper.test_data
    try:
        # Inference only: don't build an autograd graph for the test batches.
        with torch.no_grad():
            for batch in data_iterator:
                data, targets = helper.get_batch(data_iterator, batch, evaluation=True)
                dataset_size += len(data)
                output = model(data)
                total_loss += nn.functional.cross_entropy(
                    output, targets, reduction='sum').item()  # sum up batch loss
                pred = output.data.max(1)[1]  # index of the highest-scoring class
                correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item()
        acc = 100.0 * (float(correct) / float(dataset_size)) if dataset_size != 0 else 0
        total_l = total_loss / dataset_size if dataset_size != 0 else 0
        # Log line fields: model name, "backdoor injected?" flag, round,
        # average loss, correct/total and accuracy percentage.
        main.logger.info('___Test {} 是否注入后门: {}, 当前轮次: {}: 平均损失: {:.4f}, '
                         '准确率: {}/{} ({:.4f}%)'.format(model.name, is_poison, epoch,
                                                      total_l, correct, dataset_size,
                                                      acc))
        if visualize:
            # NOTE(review): loss is passed as None, so only accuracy is plotted;
            # the original author left `loss=total_l` commented out here.
            model.test_vis(vis=main.vis, epoch=epoch, acc=acc, loss=None,
                           eid=helper.params['environment_name'],
                           agent_name_key=str(agent_name_key))
    finally:
        model.train()
    return (total_l, acc, correct, dataset_size)
def Mytest_poison(helper, epoch,
                  model, is_poison=False, visualize=True, agent_name_key="", test_type=None):
    """Evaluate ``model`` on the poisoned test set, then log and optionally plot it.

    Each batch of ``helper.test_data_poison`` goes through
    ``helper.get_poison_batch(batch, adversarial_index=-1, evaluation=True,
    test_type=test_type)``. That call returns ``(data, targets, poison_num)``.

    NOTE(review): ``adversarial_index=-1`` presumably selects the combined
    (global) trigger; confirm this against ``helper.get_poison_batch``.

    Args:
        helper: Experiment helper; must provide ``test_data_poison``,
            ``get_poison_batch`` and ``params['environment_name']``.
        epoch: Current round/epoch; only used in the log line and the plot.
        model: Classifier to evaluate; must expose ``name`` and
            ``poison_test_vis``.
        is_poison: Only echoed in the log line.
        visualize: If true, report accuracy through ``model.poison_test_vis``
            using ``main.vis``.
        agent_name_key: Forwarded (converted to ``str``) to ``poison_test_vis``.
        test_type: Passed through unchanged to ``get_poison_batch``.

    Returns:
        Tuple ``(avg_loss, acc_percent, correct, poison_data_count)``. Average
        loss and accuracy are 0 when no poisoned samples were reported.

    NOTE(review): loss and ``correct`` are summed over every sample in each
    batch, but are divided by the reported poisoned-sample count. The two
    only agree if ``get_poison_batch`` poisons every sample in evaluation
    mode; confirm this.

    The model is switched back to training mode afterwards, even if
    evaluation raises.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    poison_data_count = 0
    data_iterator = helper.test_data_poison
    try:
        # Inference only: don't build an autograd graph for the test batches.
        with torch.no_grad():
            for batch in data_iterator:
                data, targets, poison_num = helper.get_poison_batch(
                    batch, adversarial_index=-1, evaluation=True, test_type=test_type)
                poison_data_count += poison_num
                output = model(data)
                total_loss += nn.functional.cross_entropy(
                    output, targets, reduction='sum').item()  # sum up batch loss
                pred = output.data.max(1)[1]  # index of the highest-scoring class
                correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item()
        # Raw debug line: correct count and poisoned-sample count.
        main.logger.info('{} {}'.format(float(correct), poison_data_count))
        acc = 100.0 * (float(correct) / float(poison_data_count)) if poison_data_count != 0 else 0
        total_l = total_loss / poison_data_count if poison_data_count != 0 else 0
        # Log line fields: model name, "backdoor injected?" flag, round,
        # average loss, correct/total and accuracy percentage.
        main.logger.info('___Test {} 是否注入后门: {}, 当前轮次: {}: 平均损失: {:.4f}, '
                         '准确率: {}/{} ({:.4f}%)'.format(model.name, is_poison, epoch,
                                                      total_l, correct, poison_data_count,
                                                      acc))
        if visualize:
            # NOTE(review): loss is passed as None, so only accuracy is plotted;
            # the original author left `loss = total_l` commented out here.
            model.poison_test_vis(vis=main.vis, epoch=epoch, acc=acc, loss=None,
                                  eid=helper.params['environment_name'],
                                  agent_name_key=str(agent_name_key))
    finally:
        model.train()
    return total_l, acc, correct, poison_data_count
def Mytest_poison_trigger(helper, model, adver_trigger_index, test_type=None):
    """Evaluate ``model`` on the poisoned test set using one chosen trigger index.

    Each batch of ``helper.test_data_poison`` goes through
    ``helper.get_poison_batch(batch, adversarial_index=adver_trigger_index,
    evaluation=True, test_type=test_type)``. Nothing is logged or plotted.

    Args:
        helper: Experiment helper; must provide ``test_data_poison`` and
            ``get_poison_batch``.
        model: Classifier to evaluate (a ``torch.nn.Module``).
        adver_trigger_index: Passed as ``adversarial_index`` to
            ``get_poison_batch``.
        test_type: Passed through unchanged to ``get_poison_batch``.

    Returns:
        Tuple ``(avg_loss, acc_percent, correct, poison_data_count)``. Average
        loss and accuracy are 0 when no poisoned samples were reported.
        Loss and ``correct`` are summed over all returned samples but divided
        by the reported poisoned count.

    The model is switched back to training mode afterwards, even if
    evaluation raises.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    poison_data_count = 0
    data_iterator = helper.test_data_poison
    try:
        # Inference only: don't build an autograd graph for the test batches.
        with torch.no_grad():
            for batch in data_iterator:
                data, targets, poison_num = helper.get_poison_batch(
                    batch, adversarial_index=adver_trigger_index,
                    evaluation=True, test_type=test_type)
                poison_data_count += poison_num
                output = model(data)
                total_loss += nn.functional.cross_entropy(
                    output, targets, reduction='sum').item()  # sum up batch loss
                pred = output.data.max(1)[1]  # index of the highest-scoring class
                correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item()
    finally:
        model.train()
    acc = 100.0 * (float(correct) / float(poison_data_count)) if poison_data_count != 0 else 0
    total_l = total_loss / poison_data_count if poison_data_count != 0 else 0
    return total_l, acc, correct, poison_data_count
def Mytest_poison_agent_trigger(helper, model, agent_name_key):
    """Evaluate ``model`` on the poisoned test set with one agent's trigger.

    The trigger index is the position of ``int(agent_name_key)`` in
    ``helper.params['total_list']``; the first match wins. If the agent is
    not listed, ``-1`` is used. That is the same index ``Mytest_poison``
    passes; NOTE(review): presumably it means the combined trigger, so
    confirm this against ``helper.get_poison_batch``.

    Args:
        helper: Experiment helper; must provide ``params['total_list']``,
            ``test_data_poison`` and ``get_poison_batch``.
        model: Classifier to evaluate (a ``torch.nn.Module``).
        agent_name_key: Agent identifier; anything ``int()`` accepts.

    Returns:
        Tuple ``(avg_loss, acc_percent, correct, poison_data_count)``. Average
        loss and accuracy are 0 when no poisoned samples were reported.

    The model is switched back to training mode afterwards, even if
    evaluation raises.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    poison_data_count = 0
    data_iterator = helper.test_data_poison
    adv_index = next(
        (idx for idx, name in enumerate(helper.params['total_list'])
         if int(agent_name_key) == name),
        -1,
    )
    try:
        # Inference only: don't build an autograd graph for the test batches.
        with torch.no_grad():
            for batch in data_iterator:
                data, targets, poison_num = helper.get_poison_batch(
                    batch, adversarial_index=adv_index, evaluation=True)
                poison_data_count += poison_num
                output = model(data)
                total_loss += nn.functional.cross_entropy(
                    output, targets, reduction='sum').item()  # sum up batch loss
                pred = output.data.max(1)[1]  # index of the highest-scoring class
                correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item()
    finally:
        model.train()
    acc = 100.0 * (float(correct) / float(poison_data_count)) if poison_data_count != 0 else 0
    total_l = total_loss / poison_data_count if poison_data_count != 0 else 0
    return total_l, acc, correct, poison_data_count