Skip to content

Commit

Permalink
Add infer scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 committed Jun 6, 2018
1 parent b6c505b commit 440642c
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 44 deletions.
105 changes: 105 additions & 0 deletions fluid/face_detection/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw

import paddle
import paddle.fluid as fluid
import reader
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments

# Command-line interface of the inference script. Every option is registered
# through the project's `add_arguments` helper, bound once to this parser.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('confs_threshold', float, 0.2, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The data root path.")
add_arg('model_dir', str, '', "The model path.")
# resize_h / resize_w default to 0; infer() only applies a fixed input shape
# when both of them are non-zero.
add_arg('resize_h', int, 0, "The resized image height.")
add_arg('resize_w', int, 0, "The resized image width.")
# yapf: enable


def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
    """Draw the detected boxes on an image and save the result.

    The image is opened from `image_path`, a red rectangle is drawn for every
    detection whose score reaches `confs_threshold`, and the annotated image
    is written to the current working directory under the base name of
    `image_path`.

    Args:
        image_path: Path of the image to annotate.
        nms_out: Iterable of detection rows. Each row is unpacked as
            (category_id, score, xmin, ymin, xmax, ymax); the coordinates are
            multiplied by the image width/height, i.e. they are treated as
            fractions of the image size.
        confs_threshold: Rows with a score below this value are not drawn.
    """
    image = Image.open(image_path)
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size

    for dt in nms_out:
        fields = np.asarray(dt).ravel().tolist()
        # NOTE(review): the detection op presumably emits a single `-1`
        # placeholder row when nothing is detected -- confirm. Rows that do
        # not carry the six expected fields are skipped instead of crashing.
        if len(fields) != 6:
            continue
        _, score, xmin, ymin, xmax, ymax = fields
        if score < confs_threshold:
            continue
        left, right = xmin * im_width, xmax * im_width
        top, bottom = ymin * im_height, ymax * im_height
        # Closed polyline: the first corner is repeated to finish the box.
        draw.line(
            [(left, top), (left, bottom), (right, bottom), (right, top),
             (left, top)],
            width=4,
            fill='red')
    image_name = os.path.basename(image_path)
    print("image with bbox drawed saved as {}".format(image_name))
    image.save(image_name)


def infer(args, data_args):
    """Run the detector on a single image and save it with boxes drawn.

    Builds the network in inference mode, loads the saved variables found in
    `args.model_dir`, runs one forward pass on the image at `args.image_path`
    and hands the detections to `draw_bounding_box_on_image`.

    Args:
        args: Parsed command-line arguments (image_path, model_dir, use_gpu,
            use_pyramidbox, resize_h, resize_w, confs_threshold).
        data_args: Reader settings object passed through to `reader.infer`.

    Raises:
        ValueError: If `args.model_dir` does not exist.
    """
    num_classes = 2
    infer_reader = reader.infer(data_args, args.image_path)
    data = infer_reader()

    # TODO(qingqing): support variable-length
    if args.resize_h and args.resize_w:
        image_shape = [3, args.resize_h, args.resize_w]
    else:
        # No explicit size requested: take the shape from the loaded batch,
        # dropping its leading (batch) dimension.
        image_shape = data.shape[1:]

    # PyramidBox takes num_classes as its second positional parameter; it has
    # no default, so it must be supplied here.
    network = PyramidBox(
        image_shape,
        num_classes,
        sub_network=args.use_pyramidbox,
        is_infer=True)
    infer_program, nmsed_out = network.infer()
    fetches = [nmsed_out]

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    model_dir = args.model_dir
    if not os.path.exists(model_dir):
        raise ValueError("The model path [%s] does not exist." % (model_dir))

    def if_exist(var):
        # Only load variables that have a file of the same name in model_dir.
        return os.path.exists(os.path.join(model_dir, var.name))

    fluid.io.load_vars(exe, model_dir, predicate=if_exist)

    feed = {'image': fluid.create_lod_tensor(data, [], place)}
    predict, = exe.run(infer_program,
                       feed=feed,
                       fetch_list=fetches,
                       return_numpy=False)
    # return_numpy=False yields a tensor object; convert it for drawing.
    predict = np.array(predict)
    draw_bounding_box_on_image(args.image_path, predict, args.confs_threshold)


if __name__ == '__main__':
    # Script entry point: parse and echo the options, assemble the reader
    # settings, then run inference on the requested image.
    args = parser.parse_args()
    print_arguments(args)

    data_dir = 'data/WIDERFACE/WIDER_val/images/'
    # NOTE(review): file_list is assigned but not used anywhere in this
    # block -- confirm whether it can be dropped.
    file_list = 'label/val_gt_widerface.res'

    # Distortion and expansion are switched off for inference.
    settings_kwargs = dict(
        data_dir=data_dir,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[104., 117., 123],
        apply_distort=False,
        apply_expand=False,
        ap_version='11point')
    data_args = reader.Settings(**settings_kwargs)
    infer(args, data_args=data_args)
75 changes: 35 additions & 40 deletions fluid/face_detection/pyramidbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,17 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):


class PyramidBox(object):
def __init__(self, data_shape, is_infer=False, sub_network=False):
def __init__(self,
data_shape,
num_classes,
is_infer=False,
sub_network=False):
self.data_shape = data_shape
self.min_sizes = [16., 32., 64., 128., 256., 512.]
self.steps = [4., 8., 16., 32., 64., 128.]
self.is_infer = is_infer
self.sub_network = sub_network
self.num_classes = num_classes

# the base network is VGG with atrous layers
self._input()
Expand All @@ -59,6 +64,8 @@ def __init__(self, data_shape, is_infer=False, sub_network=False):
self._low_level_fpn()
self._cpm_module()
self._pyramidbox()
else:
self._vgg_ssd()

def feeds(self):
if self.is_infer:
Expand Down Expand Up @@ -188,9 +195,10 @@ def _pyramidbox(self):
"""
Get prior-boxes and pyramid-box
"""
self.ssh_conv3_norm = self._l2_norm_scale(self.ssh_conv3)
self.ssh_conv4_norm = self._l2_norm_scale(self.ssh_conv4)
self.ssh_conv5_norm = self._l2_norm_scale(self.ssh_conv5)
self.ssh_conv3_norm = self._l2_norm_scale(
self.ssh_conv3, init_scale=10.)
self.ssh_conv4_norm = self._l2_norm_scale(self.ssh_conv4, init_scale=8.)
self.ssh_conv5_norm = self._l2_norm_scale(self.ssh_conv5, init_scale=5.)

def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
Expand Down Expand Up @@ -253,34 +261,41 @@ def permute_and_reshape(input, last_dim):
self.prior_boxes = fluid.layers.concat(boxes)
self.box_vars = fluid.layers.concat(vars)

def vgg_ssd(self, num_classes, image_shape):
self.conv3_norm = self._l2_norm_scale(self.conv3)
self.conv4_norm = self._l2_norm_scale(self.conv4)
self.conv5_norm = self._l2_norm_scale(self.conv5)
def _vgg_ssd(self):
self.conv3_norm = self._l2_norm_scale(self.conv3, init_scale=10.)
self.conv4_norm = self._l2_norm_scale(self.conv4, init_scale=8.)
self.conv5_norm = self._l2_norm_scale(self.conv5, init_scale=5.)

mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[
self.conv3_norm, self.conv4_norm, self.conv5_norm, self.conv6,
self.conv7, self.conv8
],
image=self.image,
num_classes=num_classes,
# min_ratio=20,
# max_ratio=90,
num_classes=self.num_classes,
min_sizes=[16.0, 32.0, 64.0, 128.0, 256.0, 512.0],
max_sizes=[[], [], [], [], [], []],
# max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
aspect_ratios=[[1.], [1.], [1.], [1.], [1.], [1.]],
steps=[4.0, 8.0, 16.0, 32.0, 64.0, 128.0],
base_size=image_shape[2],
base_size=self.data_shape[2],
offset=0.5,
flip=False)

# locs, confs, box, box_var = vgg_extra_net(num_classes, image, image_shape)
# nmsed_out = fluid.layers.detection_output(
# locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(mbox_locs, mbox_confs, self.face_box,
self.gt_label, box, box_var)
self.face_mbox_loc = mbox_locs
self.face_mbox_conf = mbox_confs
self.prior_boxes = box
self.box_vars = box_var

def vgg_ssd_loss(self):
loss = fluid.layers.ssd_loss(
self.face_mbox_loc,
self.face_mbox_conf,
self.face_box,
self.gt_label,
self.prior_boxes,
self.box_vars,
overlap_threshold=0.35,
neg_overlap=0.35)
loss = fluid.layers.reduce_sum(loss)

return loss
Expand All @@ -297,7 +312,7 @@ def train(self):
total_loss = face_loss + head_loss
return face_loss, head_loss, total_loss

def test(self):
def infer(self):
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
face_nmsed_out = fluid.layers.detection_output(
Expand All @@ -306,24 +321,4 @@ def test(self):
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
head_nmsed_out = fluid.layers.detection_output(
self.head_mbox_loc,
self.head_mbox_conf,
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
face_map_eval = fluid.evaluator.DetectionMAP(
face_nmsed_out,
self.gt_label,
self.face_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
head_map_eval = fluid.evaluator.DetectionMAP(
head_nmsed_out,
self.gt_label,
self.head_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
return test_program, face_map_eval, head_map_eval
return test_program, face_nmsed_out
26 changes: 26 additions & 0 deletions fluid/face_detection/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,29 @@ def reader():

def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)


def infer(settings, image_path):
    """Create a reader that loads one image as a network-ready batch.

    Args:
        settings: Reader settings. `resize_w` / `resize_h` (both non-zero to
            take effect) give the size the image is resized to, and
            `img_mean` is subtracted from the image after channel reordering.
        image_path: Path of the image to load.

    Returns:
        A zero-argument callable. Each call opens the image and returns a
        float32 numpy array with a leading batch dimension of 1, in
        channel-first layout with channels reordered to BGR.
    """

    def batch_reader():
        img = Image.open(image_path)
        # Force three RGB channels so the channel reordering below is valid
        # for grayscale ('L') and every other non-RGB mode as well.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if settings.resize_w and settings.resize_h:
            # ANTIALIAS is the legacy name of the LANCZOS filter and is
            # missing from recent Pillow releases; prefer the current name.
            resample = getattr(Image, 'LANCZOS', None)
            if resample is None:
                resample = Image.ANTIALIAS
            img = img.resize((settings.resize_w, settings.resize_h),
                             resample)
        img = np.array(img)
        # HWC to CHW
        img = np.transpose(img, (2, 0, 1))
        # RGB to BGR
        img = img[[2, 1, 0], :, :]
        img = img.astype('float32')
        img -= settings.img_mean
        # 0.007843 is approximately 1 / 127.5.
        img = img * 0.007843
        # Add the leading batch dimension.
        return np.array([img])

    return batch_reader
8 changes: 4 additions & 4 deletions fluid/face_detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
image_shape = [3, data_args.resize_h, data_args.resize_w]

fetches = []
network = PyramidBox(image_shape, num_classes,
sub_network=args.use_pyramidbox)
if args.use_pyramidbox:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
face_loss, head_loss, loss = network.train()
fetches = [face_loss, head_loss]
else:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
loss = network.vgg_ssd(num_classes, image_shape)
loss = network.vgg_ssd_loss()
fetches = [loss]

epocs = 12880 / batch_size
Expand Down Expand Up @@ -126,7 +126,7 @@ def save_model(postfix):
batch_id, fetch_vars[0], fetch_vars[1],
start_time - prev_start_time))

if pass_id % 10 == 0 or pass_id == num_passes - 1:
if pass_id % 1 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id))


Expand Down

0 comments on commit 440642c

Please sign in to comment.