-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathEMFINet_test.py
78 lines (56 loc) · 1.91 KB
/
EMFINet_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import glob
import os

from PIL import Image
from skimage import io
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from Data_loader import RescaleT
from Data_loader import ToTensorLab
from Data_loader import SalObjDataset
from model import MYNet
def normPRED(d):
    """Min-max normalize tensor ``d`` to the range [0, 1].

    Args:
        d: tensor of prediction values (any shape; the min/max are taken over
            all elements).

    Returns:
        ``(d - min(d)) / (max(d) - min(d))``. If every element of ``d`` is equal,
        the division would be 0/0 and produce NaN everywhere, so an all-zero
        tensor of the same shape (floating dtype) is returned instead.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        # Constant map: avoid 0/0 -> NaN, which would corrupt the saved image.
        return torch.zeros_like(d, dtype=torch.result_type(d, 1.0))
    return (d - mi) / span
def save_output(image_name,pred,d_dir):
predict = pred
predict = predict.squeeze()
predict_np = predict.cpu().data.numpy()
im = Image.fromarray(predict_np*255).convert('RGB')
img_name = image_name.split("/")[-1]
image = io.imread(image_name)
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
aaa = img_name.split(".")
bbb = aaa[0:-1]
imidx = bbb[0]
for i in range(1,len(bbb)):
imidx = imidx + "." + bbb[i]
imo.save(d_dir+imidx+'.png')
# Input images, output directory and trained weights (adjust to your layout).
image_dir = "/Image_test/"
prediction_dir = "/images_save/"
model_dir = "/model_save/EMFINet.pth"

# Sorted for a deterministic processing order. The same list is handed to the
# dataset and the loader does not shuffle, so batch i maps to img_name_list[i].
img_name_list = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
if not img_name_list:
    print(f"No .jpg images found in {image_dir}")

test_salobj_dataset = SalObjDataset(
    img_name_list=img_name_list,
    lbl_name_list=[],
    edge_name_list=[],
    transform=transforms.Compose([RescaleT(256), ToTensorLab(flag=0)]),
)
test_salobj_dataloader = DataLoader(
    test_salobj_dataset, batch_size=1, shuffle=False, num_workers=0
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = MYNet()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load(model_dir, map_location=device))
net.to(device)
net.eval()

# save_output writes into prediction_dir, so make sure it exists.
os.makedirs(prediction_dir, exist_ok=True)

# Inference only: no_grad avoids building an autograd graph per image.
with torch.no_grad():
    for i_test, data_test in enumerate(tqdm(test_salobj_dataloader)):
        inputs_test = data_test['image'].to(device, dtype=torch.float32)

        de, d1, d2, d3, d4, d5, d6, d7, d8, d9 = net(inputs_test)

        # Channel 0 of the d9 output is taken as the prediction map, then
        # min-max normalized to [0, 1].
        pred = normPRED(d9[:, 0, :, :])

        # Save the result into prediction_dir under the input image's stem.
        save_output(img_name_list[i_test], pred, prediction_dir)

        # Release the outputs before the next forward pass to lower peak memory.
        del de, d1, d2, d3, d4, d5, d6, d7, d8, d9