-
Notifications
You must be signed in to change notification settings - Fork 4
/
test_transformer.py
102 lines (67 loc) · 3.08 KB
/
test_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from sklearn.metrics import auc, roc_curve, precision_recall_curve
import numpy as np
def _score_video(model, feats, seg_per_chunk):
    """Score every segment of one video by feeding the model fixed-size windows.

    The model is always called on windows of exactly ``seg_per_chunk``
    segments. The final window is right-padded by repeating the video's last
    segment, and the scores at those padding positions are discarded, so the
    result holds one score per real segment.

    Args:
        feats: Segment features with the segment axis at dim 1; presumably
            (batch, num_segments, feat_dim) -- TODO confirm against the loader.
        seg_per_chunk: Window length the model is called with.

    Returns:
        1-D tensor with ``feats.shape[1]`` scores, on ``feats``'s device.
        Empty if the video has no segments.
    """
    num_seg = feats.shape[1]
    chunk_scores = []
    # Step by the window size; no extra all-padding window is produced when
    # num_seg is an exact multiple of seg_per_chunk.
    for start in range(0, num_seg, seg_per_chunk):
        window = feats[:, start:start + seg_per_chunk, :]
        missing = seg_per_chunk - window.shape[1]
        if missing > 0:
            # Repeat the last real segment to fill the window (one concat
            # instead of growing the tensor one segment at a time).
            pad = feats[:, -1:, :].expand(-1, missing, -1)
            window = torch.cat((window, pad), dim=1)
        # Only the 4th model output (per-segment scores) is used here.
        _, _, _, scores, _, _ = model(window)
        # Presumably scores is (batch, seg_per_chunk, 1): drop the trailing
        # singleton axis, then average over axis 0 (batch/crop axis).
        chunk_scores.append(torch.mean(torch.squeeze(scores, 2), 0))
    if not chunk_scores:
        return feats.new_zeros(0)
    # Trim the scores that belong to padding positions.
    return torch.cat(chunk_scores)[:num_seg]


def test(dataloader, model, args, device, *, gt_path='gt-colon.npy',
         seg_per_chunk=32, frames_per_seg=16):
    """Run inference over ``dataloader`` and report frame-level ROC AUC.

    Each video is scored segment by segment (see ``_score_video``). The
    per-segment scores of all videos are concatenated in loader order and
    expanded to frame level by repeating each score ``frames_per_seg``
    times. They are then compared with the frame-level labels loaded from
    ``gt_path``. ROC AUC and the area under the precision-recall curve are
    printed.

    The model is switched to eval mode and left there; callers that keep
    training afterwards must call ``model.train()`` themselves.

    Args:
        dataloader: Iterable of ``(features, filename)`` pairs. Features are
            assumed to be (batch, num_segments, 1, feat_dim) so that
            ``squeeze(2)`` leaves the segment axis at dim 1 -- TODO confirm.
        model: Callable returning a 6-tuple whose 4th element holds the
            per-segment scores.
        args: Unused; kept for call-site compatibility.
        device: Device the features are moved to before inference.
        gt_path: ``.npy`` file of frame-level labels. They must be
            concatenated over all videos in the same order the loader yields
            them.
        seg_per_chunk: Number of segments passed to the model per call.
        frames_per_seg: Number of frames each segment score stands for.

    Returns:
        Frame-level ROC AUC (float).

    Raises:
        ValueError: If the number of predicted frames differs from the number
            of ground-truth labels.
    """
    gt = np.load(gt_path)
    with torch.no_grad():
        model.eval()
        video_scores = []
        for feats, _filename in dataloader:
            feats = feats.to(device).squeeze(2)
            video_scores.append(_score_video(model, feats, seg_per_chunk))

    # Scores stay on `device` during the loop and are moved to the CPU once.
    if video_scores:
        pred = torch.cat(video_scores).cpu().numpy()
    else:
        pred = np.zeros(0, dtype=np.float32)
    pred = np.repeat(pred, frames_per_seg)
    if len(pred) != len(gt):
        raise ValueError(
            f'{len(pred)} predicted frames but {len(gt)} ground-truth labels '
            f'in {gt_path!r}')

    fpr, tpr, _ = roc_curve(gt, pred)
    rec_auc = auc(fpr, tpr)
    precision, recall, _ = precision_recall_curve(gt, pred)
    # NOTE: this is the trapezoidal area under the PR curve, not sklearn's
    # step-wise average_precision_score, despite the 'AP' label below.
    pr_auc = auc(recall, precision)
    print('auc : ' + str(rec_auc))
    print('AP : ' + str(pr_auc))
    print('pr_auc {}'.format(pr_auc))
    return rec_auc