diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py index a04635673..595a2e616 100644 --- a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py @@ -17,10 +17,10 @@ class PostProcessor(nn.Module): """ def __init__( - self, - score_thresh=0.05, - nms=0.5, - detections_per_img=100, + self, + score_thresh=0.05, + nms=0.5, + detections_per_img=100, box_coder=None, cls_agnostic_bbox_reg=False ): @@ -123,7 +123,7 @@ def filter_results(self, boxlist, num_classes): boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") boxlist_for_class.add_field("scores", scores_j) boxlist_for_class = boxlist_nms( - boxlist_for_class, self.nms, score_field="scores" + boxlist_for_class, self.nms ) num_labels = len(boxlist_for_class) boxlist_for_class.add_field( @@ -158,9 +158,9 @@ def make_roi_box_post_processor(cfg): cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG postprocessor = PostProcessor( - score_thresh, - nms_thresh, - detections_per_img, + score_thresh, + nms_thresh, + detections_per_img, box_coder, cls_agnostic_bbox_reg ) diff --git a/maskrcnn_benchmark/structures/boxlist_ops.py b/maskrcnn_benchmark/structures/boxlist_ops.py index 45160f9ab..dc51212f4 100644 --- a/maskrcnn_benchmark/structures/boxlist_ops.py +++ b/maskrcnn_benchmark/structures/boxlist_ops.py @@ -6,7 +6,7 @@ from maskrcnn_benchmark.layers import nms as _box_nms -def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"): +def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"): """ Performs non-maximum suppression on a boxlist, with scores specified in a boxlist field via score_field. @@ -15,7 +15,7 @@ def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"): boxlist(BoxList) nms_thresh (float) max_proposals (int): if > 0, then only the top max_proposals are kept - after non-maxium suppression + after non-maximum suppression score_field (str) """ if nms_thresh <= 0: