-
Notifications
You must be signed in to change notification settings - Fork 41
/
demo.py
executable file
·155 lines (146 loc) · 6.79 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import tools.find_mxnet
import mxnet as mx
import os
import sys
from detect.detector import Detector
from symbol.symbol_factory import get_symbol
def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class,
nms_thresh=0.5, force_nms=True, nms_topk=400):
"""
wrapper for initialize a detector
Parameters:
----------
net : str
test network name
prefix : str
load model prefix
epoch : int
load model epoch
data_shape : int
resize image shape
mean_pixels : tuple (float, float, float)
mean pixel values (R, G, B)
ctx : mx.ctx
running context, mx.cpu() or mx.gpu(?)
num_class : int
number of classes
nms_thresh : float
non-maximum suppression threshold
force_nms : bool
force suppress different categories
"""
if net is not None:
if isinstance(data_shape, tuple):
data_shape = data_shape[0]
net = get_symbol(net, data_shape, num_classes=num_class, nms_thresh=nms_thresh,
force_nms=force_nms, nms_topk=nms_topk)
detector = Detector(net, prefix, epoch, data_shape, mean_pixels, ctx=ctx)
return detector
def parse_args():
parser = argparse.ArgumentParser(description='Single-shot detection network demo')
parser.add_argument('--network', dest='network', type=str, default='resnet50',
help='which network to use')
parser.add_argument('--images', dest='images', type=str, default='./data/demo/dog.jpg',
help='run demo with images, use comma to seperate multiple images')
parser.add_argument('--dir', dest='dir', nargs='?',
help='demo image directory, optional', type=str)
parser.add_argument('--ext', dest='extension', help='image extension, optional',
type=str, nargs='?')
parser.add_argument('--epoch', dest='epoch', help='epoch of trained model',
default=0, type=int)
parser.add_argument('--prefix', dest='prefix', help='trained model prefix',
default=os.path.join(os.getcwd(), 'model', 'ssd_'),
type=str)
parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect',
action='store_true', default=False)
parser.add_argument('--gpu', dest='gpu_id', type=int, default=0,
help='GPU device id to detect with')
parser.add_argument('--data-shape', dest='data_shape', type=str, default='512',
help='set image shape')
parser.add_argument('--mean-r', dest='mean_r', type=float, default=123,
help='red mean value')
parser.add_argument('--mean-g', dest='mean_g', type=float, default=117,
help='green mean value')
parser.add_argument('--mean-b', dest='mean_b', type=float, default=104,
help='blue mean value')
parser.add_argument('--thresh', dest='thresh', type=float, default=0.5,
help='object visualize score threshold, default 0.6')
parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.5,
help='non-maximum suppression threshold, default 0.5')
parser.add_argument('--force', dest='force_nms', type=bool, default=True,
help='force non-maximum suppression on different class')
parser.add_argument('--timer', dest='show_timer', type=bool, default=True,
help='show detection time')
parser.add_argument('--deploy', dest='deploy_net', action='store_true', default=False,
help='Load network from json file, rather than from symbol')
parser.add_argument('--class-names', dest='class_names', type=str,
default='aeroplane, bicycle, bird, boat, bottle, bus, \
car, cat, chair, cow, diningtable, dog, horse, motorbike, \
person, pottedplant, sheep, sofa, train, tvmonitor',
help='string of comma separated names, or text filename')
args = parser.parse_args()
return args
def parse_class_names(class_names):
""" parse # classes and class_names if applicable """
if len(class_names) > 0:
if os.path.isfile(class_names):
# try to open it to read class names
with open(class_names, 'r') as f:
class_names = [l.strip() for l in f.readlines()]
else:
class_names = [c.strip() for c in class_names.split(',')]
for name in class_names:
assert len(name) > 0
else:
raise RuntimeError("No valid class_name provided...")
return class_names
def parse_data_shape(data_shape_str):
"""Parse string to tuple or int"""
ds = data_shape_str.strip().split(',')
if len(ds) == 1:
data_shape = (int(ds[0]), int(ds[0]))
elif len(ds) == 2:
data_shape = (int(ds[0]), int(ds[1]))
else:
raise ValueError("Unexpected data_shape: %s", data_shape_str)
return data_shape
if __name__ == '__main__':
args = parse_args()
if args.cpu:
ctx = mx.cpu()
else:
ctx = mx.gpu(args.gpu_id)
# parse image list
image_list = [i.strip() for i in args.images.split(',')]
assert len(image_list) > 0, "No valid image specified to detect"
network = None if args.deploy_net else args.network
class_names = parse_class_names(args.class_names)
data_shape = parse_data_shape(args.data_shape)
if args.prefix.endswith('_'):
prefix = args.prefix + args.network + '_' + str(data_shape[0])
else:
prefix = args.prefix
detector = get_detector(network, prefix, args.epoch,
data_shape,
(args.mean_r, args.mean_g, args.mean_b),
ctx, len(class_names), args.nms_thresh, args.force_nms)
# run detection
detector.detect_and_visualize(image_list, args.dir, args.extension,
class_names, args.thresh, args.show_timer)