Skip to content

Commit

Permalink
Support dino from mmdet (#2410)
Browse files Browse the repository at this point in the history
* detr batch infer

* support dino

* remove dynamic batch

* update doc

* disable exporting masks for image paddings in multi-batch inference

* fix

* remove rewriting and move changes to mmdet
RunningLeon authored Sep 12, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 455ec18 commit 985a4f3
Showing 8 changed files with 119 additions and 79 deletions.
59 changes: 32 additions & 27 deletions docs/en/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
@@ -190,35 +190,40 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter

## Supported models

| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |

## Reminder

- For transformer based models, strongly suggest use `TensorRT>=8.4`.
- <i id="mask2former">Mask2Former</i> should use `TensorRT>=8.6.1` for dynamic shape inference.
- <i id="nobatchinfer">DETR-like models</i> do not support multi-batch inference.
59 changes: 32 additions & 27 deletions docs/zh_cn/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
@@ -192,35 +192,40 @@ cv2.imwrite('output_detection.png', img)

## 模型支持列表

| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |

## 注意事项

- 强烈建议使用`TensorRT>=8.4`来转换基于 `transformer` 的模型.
- <i id="mask2former">Mask2Former</i> 请使用 `TensorRT>=8.6.1` 以保证动态尺寸正常推理.
- <i id="nobatchinfer">DETR系列模型</i> 不支持多批次推理.
23 changes: 10 additions & 13 deletions mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,8 @@
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DeformableDETRHead.predict_by_feat')
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.predict_by_feat')
def detrhead__predict_by_feat__default(self,
@@ -17,10 +19,15 @@ def detrhead__predict_by_feat__default(self,
rescale: bool = True):
"""Rewrite `predict_by_feat` of `FoveaHead` for default backend."""
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy

cls_scores = all_cls_scores_list[-1]
bbox_preds = all_bbox_preds_list[-1]

img_shape = batch_img_metas[0]['img_shape']
if isinstance(img_shape, list):
img_shape = torch.tensor(
img_shape, dtype=torch.long, device=cls_scores.device)
img_shape = img_shape.unsqueeze(0)

max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0]))
batch_size = cls_scores.size(0)
# `batch_index_offset` is used for the gather of concatenated tensor
@@ -49,19 +56,9 @@ def detrhead__predict_by_feat__default(self,
...].squeeze(-1)

det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)

if isinstance(img_shape, torch.Tensor):
hw = img_shape.flip(0).to(det_bboxes.device)
else:
hw = det_bboxes.new_tensor([img_shape[1], img_shape[0]])
shape_scale = torch.cat([hw, hw])
shape_scale = shape_scale.view(1, 1, -1)
det_bboxes.clamp_(min=0., max=1.)
shape_scale = img_shape.flip(1).repeat(1, 2).unsqueeze(1)
det_bboxes = det_bboxes * shape_scale
# dynamically clip bboxes
x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, img_shape)
det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)

return det_bboxes, det_labels
8 changes: 6 additions & 2 deletions mmdeploy/codebase/mmdet/models/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,10 @@
single_stage, single_stage_instance_seg, two_stage)

__all__ = [
'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage',
'panoptic_two_stage_segmentor', 'maskformer'
'base_detr',
'single_stage',
'single_stage_instance_seg',
'two_stage',
'panoptic_two_stage_segmentor',
'maskformer',
]
8 changes: 4 additions & 4 deletions mmdeploy/codebase/mmdet/models/detectors/base_detr.py
Original file line number Diff line number Diff line change
@@ -47,8 +47,8 @@ def _set_metainfo(data_samples, img_shape):


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base_detr.DetectionTransformer.predict')
def detection_transformer__predict(self,
'mmdet.models.detectors.base_detr.DetectionTransformer.forward')
def detection_transformer__forward(self,
batch_inputs: torch.Tensor,
data_samples: OptSampleList = None,
rescale: bool = True,
@@ -79,11 +79,11 @@ def detection_transformer__predict(self,

# get origin input shape as tensor to support onnx dynamic shape
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
img_shape = torch._shape_as_tensor(batch_inputs)[2:]
img_shape = torch._shape_as_tensor(batch_inputs)[2:].to(
batch_inputs.device)
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]

# set the metainfo
data_samples = _set_metainfo(data_samples, img_shape)

return __predict_impl(self, batch_inputs, data_samples, rescale)
2 changes: 1 addition & 1 deletion mmdeploy/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -245,7 +245,7 @@ def is_dynamic_shape(deploy_cfg: Union[str, mmengine.Config],
return False

# check if 2 (height) and 3 (width) in input axes
if 2 in input_axes and 3 in input_axes:
if 2 in input_axes or 3 in input_axes:
return True

return False
37 changes: 33 additions & 4 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
@@ -311,10 +311,7 @@ models:
- configs/detr/detr_r50_8xb2-150e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic
- *pipeline_trt_dynamic_fp32

- name: CenterNet
metafile: configs/centernet/metafile.yml
@@ -416,3 +413,35 @@ models:
- deploy_config: configs/mmdet/panoptic-seg/panoptic-seg_maskformer_tensorrt_static-800x1344.py
convert_image: *convert_image
backend_test: *default_backend_test

- name: DINO
metafile: configs/dino/metafile.yml
model_configs:
- configs/dino/dino-4scale_r50_8xb2-12e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

- name: ConditionalDETR
metafile: configs/conditional_detr/metafile.yml
model_configs:
- configs/conditional_detr/conditional-detr_r50_8xb2-50e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

- name: DAB-DETR
metafile: configs/dab_detr/metafile.yml
model_configs:
- configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

- name: DeformableDETR
metafile: configs/deformable_detr/metafile.yml
model_configs:
- configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32
2 changes: 1 addition & 1 deletion tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
@@ -727,7 +727,7 @@ def test_predict_of_detr_detector(model_cfg_path, backend):
from mmdet.structures import DetDataSample
data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64)))
rewrite_inputs = {'batch_inputs': img}
wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample])
wrapped_model = WrapModel(model, 'forward', data_samples=[data_sample])
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,

0 comments on commit 985a4f3

Please sign in to comment.