diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md
index 84e1fe5922..dba7b25d27 100644
--- a/docs/en/04-supported-codebases/mmdet.md
+++ b/docs/en/04-supported-codebases/mmdet.md
@@ -190,35 +190,40 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
## Supported models
-| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
-| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
-| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
-| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
-| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
-| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
-| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
-| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
-| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
-| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
-| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
-| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
-| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
-| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
-| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
-| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
-| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
-| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
-| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
-| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
-| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
-| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
+| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
+| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
+| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
+| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
+| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
+| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
+| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
+| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
+| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
+| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
+| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
+| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
+| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
+| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
+| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
+| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
+| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
+| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
+| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
+| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
+| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
## Reminder
- For transformer based models, strongly suggest use `TensorRT>=8.4`.
- Mask2Former should use `TensorRT>=8.6.1` for dynamic shape inference.
+- DETR-like models do not support multi-batch inference.
diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md
index 17c501630f..37bfe072a1 100644
--- a/docs/zh_cn/04-supported-codebases/mmdet.md
+++ b/docs/zh_cn/04-supported-codebases/mmdet.md
@@ -192,35 +192,40 @@ cv2.imwrite('output_detection.png', img)
## 模型支持列表
-| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
-| :------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
-| [ATSS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/atss) | Object Detection | Y | Y | N | N | Y |
-| [FCOS](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
-| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/foveabox) | Object Detection | Y | N | N | N | Y |
-| [FSAF](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
-| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
-| [SSD](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
-| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | Object Detection | N | N | N | N | Y |
-| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
-| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
-| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
-| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
-| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
-| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
-| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
-| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
-| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
-| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
-| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
-| [SOLO](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
-| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
-| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
-| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
-| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
+| Model | Task | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO |
+| :-----------------------------------------------------------------------------------------------------------------: | :-------------------: | :---------: | :------: | :--: | :---: | :------: |
+| [ATSS](https://github.com/open-mmlab/mmdetection/tree/main/configs/atss) | Object Detection | Y | Y | N | N | Y |
+| [FCOS](https://github.com/open-mmlab/mmdetection/tree/main/configs/fcos) | Object Detection | Y | Y | Y | N | Y |
+| [FoveaBox](https://github.com/open-mmlab/mmdetection/tree/main/configs/foveabox) | Object Detection | Y | N | N | N | Y |
+| [FSAF](https://github.com/open-mmlab/mmdetection/tree/main/configs/fsaf) | Object Detection | Y | Y | Y | Y | Y |
+| [RetinaNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/retinanet) | Object Detection | Y | Y | Y | Y | Y |
+| [SSD](https://github.com/open-mmlab/mmdetection/tree/main/configs/ssd) | Object Detection | Y | Y | Y | N | Y |
+| [VFNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/vfnet) | Object Detection | N | N | N | N | Y |
+| [YOLOv3](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolo) | Object Detection | Y | Y | Y | N | Y |
+| [YOLOX](https://github.com/open-mmlab/mmdetection/tree/main/configs/yolox) | Object Detection | Y | Y | Y | N | Y |
+| [Cascade R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Object Detection | Y | Y | N | Y | Y |
+| [Faster R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [Faster R-CNN + DCN](https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn) | Object Detection | Y | Y | Y | Y | Y |
+| [GFL](https://github.com/open-mmlab/mmdetection/tree/main/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
+| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/main/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
+| [DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Deformable DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/deformable_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [Conditional DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/conditional_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DAB-DETR](https://github.com/open-mmlab/mmdetection/tree/main/configs/dab_detr)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [DINO](https://github.com/open-mmlab/mmdetection/tree/main/configs/dino)[\*](#nobatchinfer) | Object Detection | Y | Y | N | ? | Y |
+| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
+| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
+| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
+| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
+| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
+| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
+| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
+| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
+| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
## 注意事项
- 强烈建议使用`TensorRT>=8.4`来转换基于 `transformer` 的模型.
- Mask2Former 请使用 `TensorRT>=8.6.1` 以保证动态尺寸正常推理.
+- DETR系列模型 不支持多批次推理.
diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
index 08121bdfbb..bb2bdee2a8 100644
--- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
+++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py
@@ -8,6 +8,8 @@
from mmdeploy.core import FUNCTION_REWRITER
+@FUNCTION_REWRITER.register_rewriter(
+ 'mmdet.models.dense_heads.DeformableDETRHead.predict_by_feat')
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.predict_by_feat')
def detrhead__predict_by_feat__default(self,
@@ -17,10 +19,15 @@ def detrhead__predict_by_feat__default(self,
rescale: bool = True):
"""Rewrite `predict_by_feat` of `FoveaHead` for default backend."""
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy
+
cls_scores = all_cls_scores_list[-1]
bbox_preds = all_bbox_preds_list[-1]
-
img_shape = batch_img_metas[0]['img_shape']
+ if isinstance(img_shape, list):
+ img_shape = torch.tensor(
+ img_shape, dtype=torch.long, device=cls_scores.device)
+ img_shape = img_shape.unsqueeze(0)
+
max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0]))
batch_size = cls_scores.size(0)
# `batch_index_offset` is used for the gather of concatenated tensor
@@ -49,19 +56,9 @@ def detrhead__predict_by_feat__default(self,
...].squeeze(-1)
det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
-
- if isinstance(img_shape, torch.Tensor):
- hw = img_shape.flip(0).to(det_bboxes.device)
- else:
- hw = det_bboxes.new_tensor([img_shape[1], img_shape[0]])
- shape_scale = torch.cat([hw, hw])
- shape_scale = shape_scale.view(1, 1, -1)
+ det_bboxes.clamp_(min=0., max=1.)
+ shape_scale = img_shape.flip(1).repeat(1, 2).unsqueeze(1)
det_bboxes = det_bboxes * shape_scale
- # dynamically clip bboxes
- x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
- from mmdeploy.codebase.mmdet.deploy import clip_bboxes
- x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, img_shape)
- det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)
return det_bboxes, det_labels
diff --git a/mmdeploy/codebase/mmdet/models/detectors/__init__.py b/mmdeploy/codebase/mmdet/models/detectors/__init__.py
index 460694aa72..ac1d82d7a6 100644
--- a/mmdeploy/codebase/mmdet/models/detectors/__init__.py
+++ b/mmdeploy/codebase/mmdet/models/detectors/__init__.py
@@ -3,6 +3,10 @@
single_stage, single_stage_instance_seg, two_stage)
__all__ = [
- 'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage',
- 'panoptic_two_stage_segmentor', 'maskformer'
+ 'base_detr',
+ 'single_stage',
+ 'single_stage_instance_seg',
+ 'two_stage',
+ 'panoptic_two_stage_segmentor',
+ 'maskformer',
]
diff --git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py
index 3531c9183c..42c0cf45f6 100644
--- a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py
+++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py
@@ -47,8 +47,8 @@ def _set_metainfo(data_samples, img_shape):
@FUNCTION_REWRITER.register_rewriter(
- 'mmdet.models.detectors.base_detr.DetectionTransformer.predict')
-def detection_transformer__predict(self,
+ 'mmdet.models.detectors.base_detr.DetectionTransformer.forward')
+def detection_transformer__forward(self,
batch_inputs: torch.Tensor,
data_samples: OptSampleList = None,
rescale: bool = True,
@@ -79,11 +79,11 @@ def detection_transformer__predict(self,
# get origin input shape as tensor to support onnx dynamic shape
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
- img_shape = torch._shape_as_tensor(batch_inputs)[2:]
+ img_shape = torch._shape_as_tensor(batch_inputs)[2:].to(
+ batch_inputs.device)
if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape]
# set the metainfo
data_samples = _set_metainfo(data_samples, img_shape)
-
return __predict_impl(self, batch_inputs, data_samples, rescale)
diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py
index 5565596fee..6af418421a 100644
--- a/mmdeploy/utils/config_utils.py
+++ b/mmdeploy/utils/config_utils.py
@@ -245,7 +245,7 @@ def is_dynamic_shape(deploy_cfg: Union[str, mmengine.Config],
return False
# check if 2 (height) and 3 (width) in input axes
- if 2 in input_axes and 3 in input_axes:
+ if 2 in input_axes or 3 in input_axes:
return True
return False
diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml
index 1df7404e5e..f0e813ce8e 100644
--- a/tests/regression/mmdet.yml
+++ b/tests/regression/mmdet.yml
@@ -311,10 +311,7 @@ models:
- configs/detr/detr_r50_8xb2-150e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- - deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
- convert_image: *convert_image
- backend_test: *default_backend_test
- sdk_config: *sdk_dynamic
+ - *pipeline_trt_dynamic_fp32
- name: CenterNet
metafile: configs/centernet/metafile.yml
@@ -416,3 +413,35 @@ models:
- deploy_config: configs/mmdet/panoptic-seg/panoptic-seg_maskformer_tensorrt_static-800x1344.py
convert_image: *convert_image
backend_test: *default_backend_test
+
+ - name: DINO
+ metafile: configs/dino/metafile.yml
+ model_configs:
+ - configs/dino/dino-4scale_r50_8xb2-12e_coco.py
+ pipelines:
+ - *pipeline_ort_dynamic_fp32
+ - *pipeline_trt_dynamic_fp32
+
+ - name: ConditionalDETR
+ metafile: configs/conditional_detr/metafile.yml
+ model_configs:
+ - configs/conditional_detr/conditional-detr_r50_8xb2-50e_coco.py
+ pipelines:
+ - *pipeline_ort_dynamic_fp32
+ - *pipeline_trt_dynamic_fp32
+
+ - name: DAB-DETR
+ metafile: configs/dab_detr/metafile.yml
+ model_configs:
+ - configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
+ pipelines:
+ - *pipeline_ort_dynamic_fp32
+ - *pipeline_trt_dynamic_fp32
+
+ - name: DeformableDETR
+ metafile: configs/deformable_detr/metafile.yml
+ model_configs:
+ - configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
+ pipelines:
+ - *pipeline_ort_dynamic_fp32
+ - *pipeline_trt_dynamic_fp32
diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py
index 78b3255b06..ca1c5c1255 100644
--- a/tests/test_codebase/test_mmdet/test_mmdet_models.py
+++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py
@@ -727,7 +727,7 @@ def test_predict_of_detr_detector(model_cfg_path, backend):
from mmdet.structures import DetDataSample
data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64)))
rewrite_inputs = {'batch_inputs': img}
- wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample])
+ wrapped_model = WrapModel(model, 'forward', data_samples=[data_sample])
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,