-
Notifications
You must be signed in to change notification settings - Fork 11
/
test.py
93 lines (52 loc) · 1.83 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Code for testing on real SAR images
# Author: Malsha Perera
import argparse
import torch
import torchvision
from torch import nn
from torchvision.transforms import functional as F
import os
import numpy as np
import torch
from transform_main import TransSAR, TransSARV2, TransSARV3
import cv2
def build_parser():
    """Create the command-line parser for the single-image test script.

    ``--cuda``, ``--load``, ``--model`` and ``--crop`` are accepted so that
    existing command lines keep working, but nothing in this script reads
    them: the network built is always ``TransSARV2`` and the device is
    selected with ``--device``.
    """
    parser = argparse.ArgumentParser(description='TransSAR')
    parser.add_argument('--cuda', default="on", type=str,
                        help='switch on/off cuda option (default: on)')
    parser.add_argument('--load', default='default', type=str,
                        help='turn on img augmentation (default: default)')
    parser.add_argument('--save_path', required=True, type=str,
                        help='directory the result image is written to')
    parser.add_argument('--model', type=str,
                        help='model name')
    parser.add_argument('--crop', type=int, default=None)
    parser.add_argument('--device', default='cuda', type=str,
                        help='torch device to run on (default: cuda)')
    parser.add_argument('--loadmodel', default='load', type=str,
                        help='path of the saved state dict (default: load)')
    parser.add_argument('--image', default='./test_images/test_01.png',
                        type=str,
                        help='input image '
                             '(default: ./test_images/test_01.png)')
    return parser


def resolve_device(name):
    """Return the ``torch.device`` for *name*, using the CPU if CUDA is absent."""
    device = torch.device(name)
    if device.type == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available, falling back to CPU")
        return torch.device("cpu")
    return device


def strip_data_parallel_prefix(state_dict):
    """Drop the leading ``module.`` from every key of *state_dict*.

    ``nn.DataParallel`` registers the wrapped network under the attribute
    ``module``, so a state dict saved from a wrapped model carries that
    prefix on all of its keys. The mapping is returned unchanged unless
    every key has the prefix.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def to_unit_range(img):
    """Map pixel values ``v`` to ``(v + 1) / 256`` as a float32 array."""
    return (np.float32(img) + 1.0) / 256.0


def to_uint8_image(pred):
    """Invert :func:`to_unit_range` and return an 8-bit image.

    Computes ``pred * 256 - 1``, rounds to the nearest integer and clips to
    ``[0, 255]`` so the array handed to ``cv2.imwrite`` is already 8-bit
    instead of depending on the encoder's implicit depth conversion.
    """
    pixels = np.asarray(pred, dtype=np.float32) * 256 - 1
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)


def result_path(save_path, im_file):
    """Return ``<save_path>/<image stem>_results.png``.

    ``os.path.join`` is used so the result lands inside *save_path* whether
    or not the directory was given with a trailing separator.
    """
    stem = os.path.splitext(os.path.basename(im_file))[0]
    return os.path.join(save_path, stem + '_results.png')


def main(argv=None):
    """Run TransSARV2 on one grayscale image and save the prediction.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the input image cannot be read.
        OSError: If the result image cannot be written.
    """
    args = build_parser().parse_args(argv)
    device = resolve_device(args.device)

    model = TransSARV2()

    # Load on the CPU first so a checkpoint saved from a GPU also opens on a
    # machine without one; the model is moved to the target device below.
    state_dict = torch.load(args.loadmodel, map_location="cpu")
    # Load into the bare network before any DataParallel wrapping, so the
    # same checkpoint works regardless of how many GPUs are present.
    if set(state_dict) != set(model.state_dict()):
        state_dict = strip_data_parallel_prefix(state_dict)
    model.load_state_dict(state_dict)

    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    else:
        model.to(device)
    model.eval()

    os.makedirs(args.save_path, exist_ok=True)

    img = cv2.imread(args.image, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError("could not read image: " + args.image)

    # to_tensor turns the 2-D array into a (1, H, W) tensor; unsqueeze adds
    # the batch dimension. The input must live on the same device as the
    # model weights.
    x = F.to_tensor(to_unit_range(img)).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred_im = model(x)

    result = to_uint8_image(pred_im.cpu().numpy().squeeze())

    filepath_out = result_path(args.save_path, args.image)
    if not cv2.imwrite(filepath_out, result):
        raise OSError("could not write image: " + filepath_out)

    print('done')


if __name__ == "__main__":
    main()