Skip to content

Commit

Permalink
fix framework warning in solov2 (PaddlePaddle#2052)
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill authored Jan 13, 2021
1 parent 8dfbd86 commit 5afa83e
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 21 deletions.
7 changes: 6 additions & 1 deletion ppdet/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
import copy
import collections

try:
collectionsAbc = collections.abc
except AttributeError:
collectionsAbc = collections

from .config.schema import SchemaDict, SharedConfig, extract_schema
from .config.yaml_helpers import serializable

Expand Down Expand Up @@ -115,7 +120,7 @@ def dict_merge(dct, merge_dct):
"""
for k, v in merge_dct.items():
if (k in dct and isinstance(dct[k], dict) and
isinstance(merge_dct[k], collections.Mapping)):
isinstance(merge_dct[k], collectionsAbc.Mapping)):
dict_merge(dct[k], merge_dct[k])
else:
dct[k] = merge_dct[k]
Expand Down
16 changes: 9 additions & 7 deletions ppdet/modeling/anchor_heads/solov2_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ def _conv_pred(self, conv_feat, num_filters, is_test, name, name_feat=None):
def _points_nms(self, heat, kernel=2):
hmax = fluid.layers.pool2d(
input=heat, pool_size=kernel, pool_type='max', pool_padding=1)
keep = fluid.layers.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
return heat * keep
keep = fluid.layers.cast(
paddle.equal(hmax[:, :, :-1, :-1], heat), 'float32')
return paddle.multiply(heat, keep)

def _split_feats(self, feats):
return (paddle.nn.functional.interpolate(
Expand Down Expand Up @@ -376,7 +377,7 @@ def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
strides.append(
fluid.layers.fill_constant(
shape=[int(size_trans[_ind])],
dtype="int32",
dtype="float32",
value=self.segm_strides[_ind]))
strides = fluid.layers.concat(strides)
strides = fluid.layers.gather(strides, index=inds[:, 0])
Expand All @@ -389,7 +390,7 @@ def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
seg_masks = fluid.layers.cast(seg_masks, 'float32')
sum_masks = fluid.layers.reduce_sum(seg_masks, dim=[1, 2])

keep = fluid.layers.where(sum_masks > strides)
keep = fluid.layers.where(paddle.greater_than(sum_masks, strides))
keep = fluid.layers.squeeze(keep, axes=[1])
# Prevent empty and increase fake data
keep_other = fluid.layers.concat([
Expand All @@ -409,9 +410,10 @@ def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
cate_scores = fluid.layers.gather(cate_scores, index=keep_scores)

# mask scoring.
seg_mul = fluid.layers.cast(seg_preds * seg_masks, 'float32')
seg_scores = fluid.layers.reduce_sum(seg_mul, dim=[1, 2]) / sum_masks
cate_scores *= seg_scores
seg_mul = fluid.layers.cast(
paddle.multiply(seg_preds, seg_masks), 'float32')
seg_scores = paddle.divide(paddle.sum(seg_mul, axis=[1, 2]), sum_masks)
cate_scores = paddle.multiply(cate_scores, seg_scores)

# Matrix NMS
seg_preds, cate_scores, cate_labels = self.mask_nms(
Expand Down
3 changes: 2 additions & 1 deletion ppdet/modeling/backbones/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from collections import OrderedDict
import copy
import paddle
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
Expand Down Expand Up @@ -105,7 +106,7 @@ def _add_topdown_lateral(self, body_name, body_input, upper_output):
out_shape=[body_input.shape[2], body_input.shape[3]],
name=topdown_name)

return lateral + topdown
return paddle.add(lateral, topdown)

def get_output(self, body_dict):
"""
Expand Down
10 changes: 6 additions & 4 deletions ppdet/modeling/losses/solov2_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def _dice_loss(self, input, target):
target = fluid.layers.reshape(
target, shape=(fluid.layers.shape(target)[0], -1))
target = fluid.layers.cast(target, 'float32')
a = fluid.layers.reduce_sum(input * target, dim=1)
b = fluid.layers.reduce_sum(input * input, dim=1) + 0.001
c = fluid.layers.reduce_sum(target * target, dim=1) + 0.001
d = (2 * a) / (b + c)
a = fluid.layers.reduce_sum(paddle.multiply(input, target), dim=1)
b = fluid.layers.reduce_sum(
paddle.multiply(input, input), dim=1) + 0.001
c = fluid.layers.reduce_sum(
paddle.multiply(target, target), dim=1) + 0.001
d = paddle.divide((2 * a), paddle.add(b, c))
return 1 - d

def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
Expand Down
22 changes: 14 additions & 8 deletions ppdet/modeling/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,45 +1642,51 @@ def __call__(self,
sum_masks, expand_times=[n_samples]),
shape=[n_samples, n_samples])
# iou.
iou_matrix = (inter_matrix / (sum_masks_x + fluid.layers.transpose(
sum_masks_x, [1, 0]) - inter_matrix))
iou_matrix = paddle.divide(inter_matrix,
paddle.subtract(
paddle.add(sum_masks_x,
fluid.layers.transpose(
sum_masks_x, [1, 0])),
inter_matrix))
iou_matrix = paddle.triu(iou_matrix, diagonal=1)
# label_specific matrix.
cate_labels_x = fluid.layers.reshape(
fluid.layers.expand(
cate_labels, expand_times=[n_samples]),
shape=[n_samples, n_samples])
label_matrix = fluid.layers.cast(
(cate_labels_x == fluid.layers.transpose(cate_labels_x, [1, 0])),
paddle.equal(cate_labels_x,
fluid.layers.transpose(cate_labels_x, [1, 0])),
'float32')
label_matrix = paddle.triu(label_matrix, diagonal=1)

# IoU compensation
compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
compensate_iou = paddle.max(paddle.multiply(iou_matrix, label_matrix),
axis=0)
compensate_iou = fluid.layers.reshape(
fluid.layers.expand(
compensate_iou, expand_times=[n_samples]),
shape=[n_samples, n_samples])
compensate_iou = fluid.layers.transpose(compensate_iou, [1, 0])

# IoU decay
decay_iou = iou_matrix * label_matrix
decay_iou = paddle.multiply(iou_matrix, label_matrix)

# matrix nms
if self.kernel == 'gaussian':
decay_matrix = fluid.layers.exp(-1 * self.sigma * (decay_iou**2))
compensate_matrix = fluid.layers.exp(-1 * self.sigma *
(compensate_iou**2))
decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
axis=0)
decay_coefficient = paddle.min(
paddle.divide(decay_matrix, compensate_matrix), axis=0)
elif self.kernel == 'linear':
decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
decay_coefficient = paddle.min(decay_matrix, axis=0)
else:
raise NotImplementedError

# update the score.
cate_scores = cate_scores * decay_coefficient
cate_scores = paddle.multiply(cate_scores, decay_coefficient)

keep = fluid.layers.where(cate_scores >= self.update_threshold)
keep = fluid.layers.squeeze(keep, axes=[1])
Expand Down

0 comments on commit 5afa83e

Please sign in to comment.