Skip to content

Commit

Permalink
[Fix] Fix VFNet test (open-mmlab#281)
Browse files Browse the repository at this point in the history
* [Fix] fix bugs for mmcls performance test (open-mmlab#269)

* fix bugs for mmcls performance test

* fix yapf

* add comments of CLASSES attribute

* Fix test_get_bboxes_of_vfnet_head

* Fix

Co-authored-by: hanrui1sensetime <[email protected]>
  • Loading branch information
SemyonBevzuk and hanrui1sensetime authored Dec 13, 2021
1 parent 0f90a0a commit a96e5f9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 54 deletions.
1 change: 0 additions & 1 deletion docs/src/pytorch-sphinx-theme
Submodule pytorch-sphinx-theme deleted from d2ed95
8 changes: 2 additions & 6 deletions mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,9 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
cls_score = sum(ms_scores) / float(len(ms_scores))
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
scale_factor = img_metas[0].get('scale_factor', None)
det_bboxes, det_labels = self.bbox_head[-1].get_bboxes(
rois,
cls_score,
bbox_pred,
max_shape,
img_metas[0]['scale_factor'],
cfg=rcnn_test_cfg)
rois, cls_score, bbox_pred, max_shape, scale_factor, cfg=rcnn_test_cfg)

if not self.with_mask:
return det_bboxes, det_labels
Expand Down
55 changes: 8 additions & 47 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,44 +1053,13 @@ def get_vfnet_head_model():
return model


@pytest.mark.parametrize('backend_type', [Backend.OPENVINO])
@pytest.mark.parametrize('backend_type',
[Backend.OPENVINO, Backend.ONNXRUNTIME])
def test_get_bboxes_of_vfnet_head(backend_type: Backend):
"""Test get_bboxes rewrite of VFNet head."""
check_backend(backend_type)

class TestModel(torch.nn.Module):
"""Stub for VFNetHead with fake bbox_preds operations.
Then bbox_preds will be one of the inputs to the ONNX graph.
"""

def __init__(self, vfnet_head):
super().__init__()
self.vfnet_head = vfnet_head

def get_bboxes(self,
cls_scores,
bbox_preds,
bbox_preds_refine,
img_metas,
cfg=None,
rescale=None,
with_nms=True):
tmp_bbox_pred_refine = []
for bbox_pred, bbox_pred_refine in zip(bbox_preds,
bbox_preds_refine):
tmp = bbox_pred_refine + bbox_pred
tmp = tmp - bbox_pred
tmp_bbox_pred_refine.append(tmp)
bbox_preds_refine = tmp_bbox_pred_refine
return self.vfnet_head.get_bboxes(cls_scores, bbox_preds,
bbox_preds_refine, img_metas,
cfg, rescale, with_nms)

test_model = TestModel(get_vfnet_head_model())
test_model.requires_grad_(False)
test_model.cpu().eval()

vfnet_head = get_vfnet_head_model()
vfnet_head.cpu().eval()
s = 16
img_metas = [{
'scale_factor': np.ones(4),
Expand All @@ -1116,32 +1085,24 @@ def get_bboxes(self,

seed_everything(1234)
cls_score = [
torch.rand(1, test_model.vfnet_head.num_classes, pow(2, i), pow(2, i))
torch.rand(1, vfnet_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
seed_everything(9101)
bbox_preds_refine = [
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]

model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'bbox_preds_refine': bbox_preds_refine,
'img_metas': img_metas
}
model_outputs = get_model_outputs(test_model, 'get_bboxes', model_inputs)
model_outputs = get_model_outputs(vfnet_head, 'get_bboxes', model_inputs)

img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
test_model, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'bbox_preds_refine': bbox_preds_refine
}
vfnet_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
rewrite_inputs = {'cls_scores': cls_score, 'bbox_preds': bboxes}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
Expand Down

0 comments on commit a96e5f9

Please sign in to comment.