-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
75 lines (54 loc) · 2.19 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import glob
import os

import torch
from PIL import Image
from skimage import io
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from data_loader import *
from model import MINet
def normPRED(x):
    """Min-max normalise a tensor so its values span [0, 1].

    Args:
        x: tensor of predictions (presumably a floating-point saliency map —
            TODO confirm for callers other than this script).

    Returns:
        ``(x - min) / (max - min)``. If every element is equal (zero range),
        a tensor of zeros shaped like ``x`` is returned instead of the
        all-NaN result that 0/0 would produce.
    """
    max_val = torch.max(x)
    min_val = torch.min(x)
    value_range = max_val - min_val
    if value_range == 0:
        # Constant map: 0/0 would yield NaN everywhere and corrupt the saved image.
        return torch.zeros_like(x)
    return (x - min_val) / value_range
def save_output(image_dir, image_name, pred, save_dir):
    """Save a prediction map as a PNG at the source image's resolution.

    Args:
        image_dir: directory containing the source ``<image_name>.bmp``.
        image_name: file stem shared by the source image and the output PNG.
        pred: tensor that squeezes to a 2-D map; presumably already scaled to
            [0, 1] (e.g. by ``normPRED``) — TODO confirm for other callers.
        save_dir: directory that ``<image_name>.png`` is written into.
    """
    predict = pred.squeeze().detach().cpu().numpy()
    predict = Image.fromarray(predict * 255).convert('RGB')
    # Only the source size is needed: Image.open reads the header lazily
    # instead of decoding the full bitmap as io.imread did.
    with Image.open(os.path.join(image_dir, image_name + '.bmp')) as image:
        width, height = image.size
    predict = predict.resize((width, height), resample=Image.BILINEAR)
    predict.save(os.path.join(save_dir, image_name + '.png'))
if __name__ == '__main__':
    # --------- Define the address and image format ---------
    image_dir = "./Dataset/SD-saliency-900/Img_test/"
    prediction_dir = "./results/"
    model_dir = "./model_save/MINet.pth"
    # sorted(): glob order is file-system dependent; make runs reproducible.
    img_name_list = sorted(glob.glob(image_dir + '*.bmp'))
    # save_output writes into prediction_dir, so make sure it exists.
    os.makedirs(prediction_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --------- Load the data ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list, lbl_name_list=[], transform=transforms.Compose([Rescale(368), ToTensor(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=64, shuffle=False, num_workers=4)

    # --------- Define the model ---------
    print("...load MINet...")
    net = MINet()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(model_dir, map_location=device))
    net.to(device)
    net.eval()

    # --------- Generate prediction images ---------
    # Inference only: no_grad avoids building autograd graphs per batch.
    with torch.no_grad():
        for data_test in tqdm(test_salobj_dataloader):
            inputs_test = data_test['image'].float().to(device)
            name_list = data_test['name']
            d1, d2, d3, d4, d5 = net(inputs_test)
            # Only the first output (d1) is saved; channel 0 of each sample.
            for i in range(d1.shape[0]):
                pred = normPRED(d1[i, 0, :, :])
                save_output(image_dir, name_list[i], pred, prediction_dir)
            del d1, d2, d3, d4, d5