-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
131 lines (94 loc) · 4.82 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import torch
import argparse
from trainers.test import test_model
from utils.log_helper import init_log, add_file_handler
import logging
from utils.flops_counter import add_flops_counting_methods, flops_to_string, get_model_parameters_number
import models
import importlib
# Evaluate on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"``) to ``True``, so boolean flags need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='PyTorch VOS Testing')
# Dataset roots default to the original author's local paths
# (e.g. /media/hyojin/SSD1TB/Dataset); override them on the command line.
# Data loading
parser.add_argument('--cache_dir', default='./cache', type=str, help='cache dir')
parser.add_argument('--Ytb_dir', default='/media/hyojin/SSD1TB1/Dataset/Youtube-VOS2019', type=str,
                    help='Ytb dataset dir')
parser.add_argument('--Davis_dir', default='/media/hyojin/SSD1TB1/Dataset/DAVIS', type=str,
                    help='Davis dataset dir')
parser.add_argument('--nnWeight', default='./nnWeight', type=str, help='nn_weights_path')
# Model structure
parser.add_argument('--backbone', default='hrnetv2Sv1',
                    choices=['resnet50s16', 'resnet18s16', 'mobileNetV3Larges16', 'hrnetv2Sv1'],
                    help='architecture of backbone model')
parser.add_argument('--refine', default='v3', choices=['v1', 'v2', 'v3'])
# type=_str2bool so that e.g. "--DiffLoss False" is converted before the choices check.
parser.add_argument('--DiffLoss', default=True, type=_str2bool, choices=[True, False])
parser.add_argument('--TmpLoss', default=True, type=_str2bool, choices=[True, False])
parser.add_argument('--LT', default='slim', choices=['Not', 'gc', 'slim'])
parser.add_argument('--ltG', default=4, type=int, help="number of group in slim")
parser.add_argument('--ltLoc', default='s8', type=str, help="input feature for long-term")
# Etc
parser.add_argument('--parallel', default=False, type=_str2bool, help='use parallel')
parser.add_argument("--debug", type=_str2bool, default=False, help="Vis out detail feature map")
parser.add_argument("--log", type=str, default="logTest.txt", help="Out log file")
# Kept as a string on purpose: main() converts it with `== "True"`.
parser.add_argument("--TestInfo", default="False", help="show detail eval result")
parser.add_argument('--save_dir', default='./save_dir/', type=str, help='save dir')
parser.add_argument("--pth", default="", help="get pth file")  # RN18
def _build_model(args):
    """Construct the ``models.VOS`` network in eval mode for ``args.backbone``.

    The per-backbone configuration tuples are passed through unchanged; their
    meaning is defined by ``models.VOS``, which is not visible here.

    Raises:
        ValueError: if ``args.backbone`` matches none of the supported families.
    """
    if args.backbone == "resnet50s16":
        backbone_cfg = ('resnet50s16', (True, ('layer3',), ('layer3',), ('layer2',), ('layer1',),
                                        args.nnWeight))
    elif "hrnet" in args.backbone:
        backbone_cfg = (args.backbone, (True, ('stage4',), args.nnWeight))
    elif args.backbone == "resnet18s16":
        backbone_cfg = ('resnet18s16', (True, ('layer4',), ('layer4',), ('layer2',), ('layer1',),
                                        args.nnWeight))
    elif "mobileNetV3" in args.backbone:
        backbone_cfg = (args.backbone, (True, ('layer4',), ('layer4',), ('layer2',), ('layer1',),
                                        args.nnWeight))
    else:
        raise ValueError("Unsupported backbone: {}".format(args.backbone))
    return models.VOS(backbone=backbone_cfg, mode='eval', args=args)


def _log_complexity(model, logger):
    """Run one dummy forward pass through ``model`` and log its FLOPs and parameter count.

    The dummy tensor sizes are the ones the original script used. The axis
    meaning is assumed, not verified: presumably (batch, frames, channels, H, W)
    for the clip and (batch, 1, H, W) for the mask (TODO confirm against models.VOS).
    """
    x = torch.rand(1, 2, 3, 480, 854)
    gt = torch.zeros(1, 1, 480, 854)
    gt[:, :, 100:200, 100:200] = 1
    gt_set = [gt.long(), None]
    model.train(False)
    model_eval = add_flops_counting_methods(model)
    model_eval.start_flops_count()
    # Counting only needs the forward pass; skip building an autograd graph.
    with torch.no_grad():
        model_eval(x, gt_set, None)
    n_flops = model.compute_average_flops_cost()
    # Detach the counting hooks so they do not keep running during evaluation.
    stop = getattr(model_eval, 'stop_flops_count', None)
    if stop is not None:
        stop()
    logger.info("input size is : {}".format(str(x.size())))
    logger.info('Flops: {}'.format(flops_to_string(n_flops)))
    logger.info('Params: ' + get_model_parameters_number(model))


def _prepare_output(args):
    """Resolve the checkpoint path and redirect outputs into ``<save_dir>/DEBUG``.

    Mutates ``args.save_dir`` (and ``args.log`` when non-empty) in place and
    attaches a file handler to the 'global' logger.

    Returns:
        The checkpoint path joined onto the *original* save_dir, or ``None``
        when ``--pth`` is empty.
    """
    pthname = os.path.join(args.save_dir, args.pth) if args.pth else None
    args.save_dir = os.path.join(args.save_dir, "DEBUG")
    # makedirs: the parent save_dir may not exist yet, which would make os.mkdir fail.
    os.makedirs(args.save_dir, exist_ok=True)
    # Test for an empty name BEFORE joining; after the join the path is never "".
    if args.log:
        args.log = os.path.join(args.save_dir, args.log)
        add_file_handler('global', args.log, logging.INFO)
    return pthname


def main():
    """Parse CLI args, build the VOS model, report its cost, load weights, and run test_model."""
    args = parser.parse_args()
    init_log('global', logging.INFO)
    # --TestInfo arrives as a string; convert it to a bool here.
    args.TestInfo = (args.TestInfo == "True")
    logger = logging.getLogger('global')
    # Set up the output dir and log file first so the lines below also reach the file.
    pthname = _prepare_output(args)

    model = _build_model(args)
    logger.info(args)
    logger.info("Start testing ...")
    logger.info("Using : ")
    logger.info(" Backbone : " + (args.backbone))
    logger.info(" Refine : " + (args.refine))
    logger.info(" LT : " + (args.LT))
    _log_complexity(model, logger)

    model.to(DEVICE)
    if pthname is not None:
        model.load_state_dict(torch.load(pthname, map_location=DEVICE)['net'])
        logger.info("Take pth file : {}".format(pthname))
    if torch.cuda.device_count() > 1 and args.parallel:
        model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    else:
        # Honour the CPU fallback chosen for DEVICE instead of forcing .cuda().
        model = model.to(DEVICE)
    test_model(model, 0, args)
if __name__ == '__main__':
    # Script entry point: run the evaluation only when executed directly, not on import.
    main()