-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
78 lines (60 loc) · 2.31 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import argparse
from tqdm import tqdm
import yaml
from attrdict import AttrMap
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from data_manager import TestDataset
from utils import gpu_manage, save_image, heatmap
from models.gen.SPANet import Generator
def _compose_panels(in_rgb, out_rgb, heat_att):
    """Tile three channel-first panels side by side into one HWC image.

    Args:
        in_rgb: input image, shape (3, H, W); multiplied by 255 here, so
            presumably in [0, 1] — TODO confirm against TestDataset.
        out_rgb: generator output, shape (3, H, W), already clipped to [0, 1].
        heat_att: attention heatmap returned by ``utils.heatmap``; assumed to
            be already in [0, 255] and broadcastable to (3, H, W) — TODO confirm.

    Returns:
        float array of shape (H, 3 * W, 3) laid out as [input | output | attention].
    """
    n_rows, n_cols, n_ch = 1, 3, 3
    # Size the canvas from the actual image instead of assuming a square
    # config.width x config.width input.
    height, width = in_rgb.shape[-2:]
    grid = np.zeros((n_rows, n_cols, n_ch, height, width))
    grid[0, 0] = in_rgb * 255
    grid[0, 1] = out_rgb * 255
    grid[0, 2] = heat_att
    # (row, col, ch, y, x) -> (row, y, col, x, ch) so that the reshape below
    # interleaves grid rows/cols with pixel rows/cols.
    grid = grid.transpose(0, 3, 1, 4, 2)
    return grid.reshape((n_rows * height, n_cols * width, n_ch))


def predict(config, args):
    """Run the SPANet generator over every test image and save comparison strips.

    For each image yielded by ``TestDataset(args.test_dir, ...)`` the generator
    produces an attention map and a restored image; a strip of
    [input | output | attention heatmap] is written with ``save_image`` into
    ``args.out_dir``.

    Args:
        config: AttrMap of the YAML config; reads ``in_ch``, ``out_ch``,
            ``threads`` and ``gpu_ids``.
        args: parsed CLI namespace; reads ``test_dir``, ``out_dir``,
            ``pretrained`` and ``cuda`` (and whatever ``gpu_manage`` reads).
    """
    gpu_manage(args)
    dataset = TestDataset(args.test_dir, config.in_ch, config.out_ch)
    data_loader = DataLoader(dataset=dataset, num_workers=config.threads, batch_size=1, shuffle=False)

    ### MODELS LOAD ###
    print('===> Loading models')
    # One device for both the model and its inputs (previously the model went to
    # cuda:0 while inputs went to the *current* CUDA device).
    device = torch.device('cuda', 0) if args.cuda else torch.device('cpu')
    gen = Generator(gpu_ids=config.gpu_ids)
    # Load onto CPU first so GPU-saved checkpoints also work without --cuda;
    # load_state_dict copies the tensors into the model's own parameters.
    param = torch.load(args.pretrained, map_location='cpu')
    gen.load_state_dict(param)
    gen = gen.to(device)
    # Inference: switch dropout / batch-norm layers (if any) to evaluation mode.
    gen.eval()

    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            x = batch[0].to(device)
            # batch_size=1, so the collated filename list has exactly one entry.
            filename = batch[1][0]
            att, out = gen(x)

            x_ = x.cpu().numpy()[0]
            out_ = out.cpu().numpy()[0]
            in_rgb = x_[:3]
            out_rgb = np.clip(out_[:3], 0, 1)
            # Clip before the uint8 cast so out-of-range attention values
            # saturate instead of wrapping around modulo 256.
            att_ = np.clip(att.cpu().numpy()[0] * 255, 0, 255)
            heat_att = heatmap(att_.astype('uint8'))

            allim = _compose_panels(in_rgb, out_rgb, heat_att)
            save_image(args.out_dir, allim, i, 1, filename=filename)
if __name__ == '__main__':
    # CLI entry point: parse arguments, load the YAML config, run prediction.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    # NOTE(review): predict() passes this straight to TestDataset; omitting it
    # probably fails there — confirm whether it should be required.
    parser.add_argument('--test_dir', type=str, required=False)
    parser.add_argument('--out_dir', type=str, required=True)
    parser.add_argument('--pretrained', type=str, required=True)
    parser.add_argument('--cuda', action='store_true')
    # nargs='+' keeps the value a list whether or not the flag is given;
    # type=int alone turned "--gpu_ids 1" into a bare int, unlike the [0] default.
    parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0])
    parser.add_argument('--manualSeed', type=int, default=0)
    args = parser.parse_args()

    with open(args.config, 'r', encoding='UTF-8') as f:
        # safe_load: yaml.load without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary Python objects from the file.
        config = yaml.safe_load(f)
    config = AttrMap(config)

    predict(config, args)