Skip to content

Commit

Permalink
remove negative samples during RCNN box and mask heads training (#981)
Browse files Browse the repository at this point in the history
* use new ops in apache/mxnet#16215

* sampler wrap around for last part

* reduce mask head num samples

* rm reshape

fix bugs

rm redundant comment

* revert rpn_channel

revert rpn_channel

revert some change

fix typo

typo

fix typo

* fix docs

fix

fix

fix

fix

fix

fix

fix docs

fix docs

docs

docs

* fix tutorial

* fix log

* fix learning rate
  • Loading branch information
Jerryzcn authored Oct 16, 2019
1 parent e1680e3 commit ec5bcec
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 160 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ The following commands install the stable version of GluonCV and MXNet:
```bash
pip install gluoncv --upgrade
pip install mxnet-mkl --upgrade
# if cuda 9.2 is installed
pip install mxnet-cu92mkl --upgrade
# if cuda 10.1 is installed
pip install mxnet-cu101mkl --upgrade
```

**The latest stable version of GluonCV is 0.4 and depends on mxnet >= 1.4.0**
Expand All @@ -66,8 +66,8 @@ You may get access to latest features and bug fixes with the following commands
```bash
pip install gluoncv --pre --upgrade
pip install mxnet-mkl --pre --upgrade
# if cuda 9.2 is installed
pip install mxnet-cu92mkl --pre --upgrade
# if cuda 10.1 is installed
pip install mxnet-cu101mkl --pre --upgrade
```

There are multiple versions of MXNet pre-built package available. Please refer to [mxnet packages](https://gluon-crash-course.mxnet.io/mxnet_packages.html) if you need more details about MXNet versions.
Expand Down
18 changes: 7 additions & 11 deletions docs/tutorials/detection/train_faster_rcnn_voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@
with autograd.train_mode():
# this time we need ground-truth to generate high quality roi proposals during training
gt_box = mx.nd.zeros(shape=(1, 1, 4))
cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(x, gt_box)
gt_label = mx.nd.zeros(shape=(1, 1, 1))
cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
box_targets, box_masks, _ = net(x, gt_box, gt_label)

##############################################################################
# In training mode, Faster-RCNN returns a lot of intermediate values, which we require to train in an end-to-end flavor,
Expand Down Expand Up @@ -272,11 +274,8 @@
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
# network forward
cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(
data.expand_dims(0), gt_box)
# generate targets for rcnn
cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
gt_label, gt_box)
cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
box_targets, box_masks, _ = net(data.expand_dims(0), gt_box, gt_label)

print('data:', data.shape)
# box and class labels
Expand All @@ -302,11 +301,8 @@
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
# network forward
cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(
data.expand_dims(0), gt_box)
# generate targets for rcnn
cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
gt_label, gt_box)
cls_preds, box_preds, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
box_targets, box_masks, _ = net(data.expand_dims(0), gt_box, gt_label)

# losses of rpn
rpn_score = rpn_score.squeeze(axis=-1)
Expand Down
46 changes: 32 additions & 14 deletions docs/tutorials/instance/train_mask_rcnn_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,9 @@
with autograd.train_mode():
# this time we need ground-truth to generate high quality roi proposals during training
gt_box = mx.nd.zeros(shape=(1, 1, 4))
cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors = net(x,
gt_box)
gt_label = mx.nd.zeros(shape=(1, 1, 1))
cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
cls_targets, box_targets, box_masks, indices = net(x, gt_box, gt_label)

##########################################################
# Training losses
Expand Down Expand Up @@ -260,14 +261,23 @@
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
# network forward
cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors = \
net(data.expand_dims(0), gt_box)
# generate targets for rcnn
cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
gt_label, gt_box)
cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \
cls_targets, box_targets, box_masks, indices = \
net(data.expand_dims(0), gt_box, gt_label)

# generate targets for mask head
roi = mx.nd.concat(
*[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1, 4))
m_cls_targets = mx.nd.concat(
*[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
matches = mx.nd.concat(
*[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
cls_targets)
m_cls_targets)

print('data:', data.shape)
# box and class labels
print('box:', gt_box.shape)
Expand Down Expand Up @@ -299,14 +309,22 @@
gt_label = label[:, :, 4:5]
gt_box = label[:, :, :4]
# network forward
cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors = \
net(data.expand_dims(0), gt_box)
# generate targets for rcnn
cls_targets, box_targets, box_masks = net.target_generator(roi, samples, matches,
gt_label, gt_box)
cls_preds, box_preds, mask_preds, roi, samples, matches, rpn_score, rpn_box, anchors, \
cls_targets, box_targets, box_masks, indices = \
net(data.expand_dims(0), gt_box, gt_label)

# generate targets for mask head
roi = mx.nd.concat(
*[mx.nd.take(roi[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1, 4))
m_cls_targets = mx.nd.concat(
*[mx.nd.take(cls_targets[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
matches = mx.nd.concat(
*[mx.nd.take(matches[i], indices[i]) for i in range(indices.shape[0])], dim=0) \
.reshape((indices.shape[0], -1))
mask_targets, mask_masks = net.mask_target(roi, masks.expand_dims(0), matches,
cls_targets)
m_cls_targets)

# losses of rpn
rpn_score = rpn_score.squeeze(axis=-1)
Expand Down
60 changes: 41 additions & 19 deletions gluoncv/model_zoo/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def __init__(self, features, top_features, classes, box_features=None,
self._batch_size = per_device_batch_size
self._num_sample = num_sample
self._rpn_test_post_nms = rpn_test_post_nms
self._target_generator = RCNNTargetGenerator(self.num_class)
self._target_generator = RCNNTargetGenerator(self.num_class, int(num_sample * pos_ratio),
self._batch_size)
self._additional_output = additional_output
with self.name_scope():
self.rpn = RPN(
Expand All @@ -207,7 +208,7 @@ def __init__(self, features, top_features, classes, box_features=None,
clip=clip, nms_thresh=rpn_nms_thresh, train_pre_nms=rpn_train_pre_nms,
train_post_nms=rpn_train_post_nms, test_pre_nms=rpn_test_pre_nms,
test_post_nms=rpn_test_post_nms, min_size=rpn_min_size,
multi_level=self.num_stages > 1)
multi_level=self.num_stages > 1, per_level_nms=False)
self.sampler = RCNNTargetSampler(num_image=self._batch_size,
num_proposal=rpn_train_post_nms, num_sample=num_sample,
pos_iou_thresh=pos_iou_thresh, pos_ratio=pos_ratio,
Expand Down Expand Up @@ -252,7 +253,8 @@ def reset_class(self, classes, reuse_weights=None):
"""
super(FasterRCNN, self).reset_class(classes, reuse_weights)
self._target_generator = RCNNTargetGenerator(self.num_class)
self._target_generator = RCNNTargetGenerator(self.num_class, self.sampler._max_pos,
self._batch_size)

def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode='align',
roi_canonical_scale=224.0, eps=1e-6):
Expand Down Expand Up @@ -292,16 +294,25 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode=
# rpn_rois = F.take(rpn_rois, roi_level_sorted_args, axis=0)
pooled_roi_feats = []
for i, l in enumerate(range(self._min_stage, max_stage + 1)):
# Pool features with all rois first, and then set invalid pooled features to zero,
# at last ele-wise add together to aggregate all features.
if roi_mode == 'pool':
# Pool features with all rois first, and then set invalid pooled features to zero,
# at last ele-wise add together to aggregate all features.
pooled_feature = F.ROIPooling(features[i], rpn_rois, roi_size, 1. / strides[i])
pooled_feature = F.where(roi_level == l, pooled_feature,
F.zeros_like(pooled_feature))
elif roi_mode == 'align':
pooled_feature = F.contrib.ROIAlign(features[i], rpn_rois, roi_size,
1. / strides[i], sample_ratio=2)
if 'box_encode' in F.contrib.__dict__ and 'box_decode' in F.contrib.__dict__:
# TODO(jerryzcn): clean this up for once mx 1.6 is released.
masked_rpn_rois = F.where(roi_level == l, rpn_rois, F.ones_like(rpn_rois) * -1.)
pooled_feature = F.contrib.ROIAlign(features[i], masked_rpn_rois, roi_size,
1. / strides[i], sample_ratio=2)
else:
pooled_feature = F.contrib.ROIAlign(features[i], rpn_rois, roi_size,
1. / strides[i], sample_ratio=2)
pooled_feature = F.where(roi_level == l, pooled_feature,
F.zeros_like(pooled_feature))
else:
raise ValueError("Invalid roi mode: {}".format(roi_mode))
pooled_feature = F.where(roi_level == l, pooled_feature, F.zeros_like(pooled_feature))
pooled_roi_feats.append(pooled_feature)
# Ele-wise add to aggregate all pooled features
pooled_roi_feats = F.ElementWiseSum(*pooled_roi_feats)
Expand All @@ -312,7 +323,7 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode=
return pooled_roi_feats

# pylint: disable=arguments-differ
def hybrid_forward(self, F, x, gt_box=None):
def hybrid_forward(self, F, x, gt_box=None, gt_label=None):
"""Forward Faster-RCNN network.
The behavior during training and inference is different.
Expand All @@ -322,7 +333,9 @@ def hybrid_forward(self, F, x, gt_box=None):
x : mxnet.nd.NDArray or mxnet.symbol
The network input tensor.
gt_box : type, only required during training
The ground-truth bbox tensor with shape (1, N, 4).
The ground-truth bbox tensor with shape (B, N, 4).
gt_label : type, only required during training
The ground-truth label tensor with shape (B, 1, 4).
Returns
-------
Expand Down Expand Up @@ -385,20 +398,29 @@ def _split(x, axis, num_outputs, squeeze_axis):
else:
box_feat = self.box_features(top_feat)
cls_pred = self.class_predictor(box_feat)
box_pred = self.box_predictor(box_feat)
# cls_pred (B * N, C) -> (B, N, C)
cls_pred = cls_pred.reshape((batch_size, num_roi, self.num_class + 1))
# box_pred (B * N, C * 4) -> (B, N, C, 4)
box_pred = box_pred.reshape((batch_size, num_roi, self.num_class, 4))

# no need to convert bounding boxes in training, just return
if autograd.is_training():
cls_targets, box_targets, box_masks, indices = \
self._target_generator(rpn_box, samples, matches, gt_label, gt_box)
box_feat = F.reshape(box_feat.expand_dims(0), (batch_size, -1, 0))
box_pred = self.box_predictor(F.concat(
*[F.take(F.slice_axis(box_feat, axis=0, begin=i, end=i + 1).squeeze(),
F.slice_axis(indices, axis=0, begin=i, end=i + 1).squeeze())
for i in range(batch_size)], dim=0))
# box_pred (B * N, C * 4) -> (B, N, C, 4)
box_pred = box_pred.reshape((batch_size, -1, self.num_class, 4))
if self._additional_output:
return (cls_pred, box_pred, rpn_box, samples, matches,
raw_rpn_score, raw_rpn_box, anchors, top_feat)
return (cls_pred, box_pred, rpn_box, samples, matches,
raw_rpn_score, raw_rpn_box, anchors)
return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box,
anchors, cls_targets, box_targets, box_masks, top_feat, indices)
return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box,
anchors, cls_targets, box_targets, box_masks, indices)

box_pred = self.box_predictor(box_feat)
# box_pred (B * N, C * 4) -> (B, N, C, 4)
box_pred = box_pred.reshape((batch_size, num_roi, self.num_class, 4))
# cls_ids (B, N, C), scores (B, N, C)
cls_ids, scores = self.cls_decoder(F.softmax(cls_pred, axis=-1))
# cls_ids, scores (B, N, C) -> (B, C, N) -> (B, C, N, 1)
Expand All @@ -419,7 +441,7 @@ def _split(x, axis, num_outputs, squeeze_axis):
results = []
for rpn_box, cls_id, score, box_pred in zip(rpn_boxes, cls_ids, scores, box_preds):
# box_pred (C, N, 4) rpn_box (1, N, 4) -> bbox (C, N, 4)
bbox = self.box_decoder(box_pred, self.box_to_center(rpn_box))
bbox = self.box_decoder(box_pred, rpn_box)
# res (C, N, 6)
res = F.concat(*[cls_id, score, bbox], dim=-1)
if self.force_nms:
Expand Down Expand Up @@ -683,7 +705,7 @@ def faster_rcnn_fpn_bn_resnet50_v1b_coco(pretrained=False, pretrained_base=True,
top_features = None
# 1 Conv 1 FC layer before RCNN cls and reg
box_features = nn.HybridSequential()
box_features.add(nn.Conv2D(256, 3, padding=1),
box_features.add(nn.Conv2D(256, 3, padding=1, use_bias=False),
SyncBatchNorm(**gluon_norm_kwargs),
nn.Activation('relu'),
nn.Dense(1024, weight_initializer=mx.init.Normal(0.01)),
Expand Down
33 changes: 22 additions & 11 deletions gluoncv/model_zoo/faster_rcnn/rcnn_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
Parameters
----------
rois: (B, self._num_input, 4) encoded in (x1, y1, x2, y2).
scores: (B, self._num_input, 1), value range [0, 1] with ignore value -1.
rois: (B, self._num_proposal, 4) encoded in (x1, y1, x2, y2).
scores: (B, self._num_proposal, 1), value range [0, 1] with ignore value -1.
gt_boxes: (B, M, 4) encoded in (x1, y1, x2, y2), invalid box should have area of 0.
Returns
Expand All @@ -65,7 +65,7 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
roi = F.squeeze(F.slice_axis(rois, axis=0, begin=i, end=i + 1), axis=0)
score = F.squeeze(F.slice_axis(scores, axis=0, begin=i, end=i + 1), axis=0)
gt_box = F.squeeze(F.slice_axis(gt_boxes, axis=0, begin=i, end=i + 1), axis=0)
gt_score = F.ones_like(F.sum(gt_box, axis=-1, keepdims=True))
gt_score = F.sign(F.sum(gt_box, axis=-1, keepdims=True) + 1)

# concat rpn roi with ground truth. mix gt with generated boxes.
all_roi = F.concat(roi, gt_box, dim=0)
Expand Down Expand Up @@ -126,9 +126,13 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
samples = F.concat(topk_samples, bottomk_samples, dim=0)
matches = F.concat(topk_matches, bottomk_matches, dim=0)

new_rois.append(all_roi.take(indices))
new_samples.append(samples)
new_matches.append(matches)
sampled_rois = all_roi.take(indices)
x1, y1, x2, y2 = F.split(sampled_rois, axis=-1, num_outputs=4, squeeze_axis=True)
rois_area = (x2 - x1) * (y2 - y1)
ind = F.argsort(rois_area)
new_rois.append(sampled_rois.take(ind))
new_samples.append(samples.take(ind))
new_matches.append(matches.take(ind))
# stack all samples together
new_rois = F.stack(*new_rois, axis=0)
new_samples = F.stack(*new_samples, axis=0)
Expand All @@ -143,18 +147,24 @@ class RCNNTargetGenerator(gluon.HybridBlock):
----------
num_class : int
Number of total number of positive classes.
max_pos : int, default is 128
Upper bound of Number of positive samples.
per_device_batch_size : int, default is 1
Per device batch size
means : iterable of float, default is (0., 0., 0., 0.)
Mean values to be subtracted from regression targets.
stds : iterable of float, default is (.1, .1, .2, .2)
Standard deviations to be divided from regression targets.
"""

def __init__(self, num_class, means=(0., 0., 0., 0.), stds=(.1, .1, .2, .2)):
def __init__(self, num_class, max_pos=128, per_device_batch_size=1, means=(0., 0., 0., 0.),
stds=(.1, .1, .2, .2)):
super(RCNNTargetGenerator, self).__init__()
self._cls_encoder = MultiClassEncoder()
self._box_encoder = NormalizedPerClassBoxCenterEncoder(
num_class=num_class, means=means, stds=stds)
num_class=num_class, max_pos=max_pos, per_device_batch_size=per_device_batch_size,
means=means, stds=stds)

# pylint: disable=arguments-differ, unused-argument
def hybrid_forward(self, F, roi, samples, matches, gt_label, gt_box):
Expand All @@ -179,6 +189,7 @@ def hybrid_forward(self, F, roi, samples, matches, gt_label, gt_box):
# cls_target (B, N)
cls_target = self._cls_encoder(samples, matches, gt_label)
# box_target, box_weight (C, B, N, 4)
box_target, box_mask = self._box_encoder(
samples, matches, roi, gt_label, gt_box)
return cls_target, box_target, box_mask
box_target, box_mask, indices = self._box_encoder(samples, matches, roi, gt_label,
gt_box)

return cls_target, box_target, box_mask, indices
Loading

0 comments on commit ec5bcec

Please sign in to comment.