diff --git a/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index c3ad05bf91..15192c2732 100644 --- a/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -3551,6 +3551,8 @@ int main(int argc, char** argv) { } } else if (op == "Where") { fprintf(pp, "%-16s", "Where"); + } else if (op == "Yolov3DetectionOutput") { + fprintf(pp, "%-16s", "Yolov3DetectionOutput"); } else { // TODO fprintf(stderr, "%s not supported yet!\n", op.c_str()); @@ -5382,6 +5384,38 @@ int main(int argc, char** argv) { fprintf(pp, ",%d", axes[i]); } } + } else if (op == "Yolov3DetectionOutput") { + int num_class = get_node_attr_i(node, "num_class"); + int num_box = get_node_attr_i(node, "num_box"); + float confidence_threshold = + get_node_attr_f(node, "confidence_threshold"); + float nms_threshold = get_node_attr_f(node, "nms_threshold"); + fprintf(pp, " 0=%d", num_class); + fprintf(pp, " 1=%d", num_box); + fprintf(pp, " 2=%e", confidence_threshold); + fprintf(pp, " 3=%e", nms_threshold); + std::vector biases = get_node_attr_af(node, "biases"); + if (biases.size() > 0) { + fprintf(pp, " -23304=%zu", biases.size()); + for (int i = 0; i < (int)biases.size(); i++) { + fprintf(pp, ",%e", biases[i]); + } + } + std::vector mask = get_node_attr_af(node, "mask"); + if (mask.size() > 0) { + fprintf(pp, " -23305=%zu", mask.size()); + for (int i = 0; i < (int)mask.size(); i++) { + fprintf(pp, ",%e", mask[i]); + } + } + std::vector anchors_scale = + get_node_attr_af(node, "anchors_scale"); + if (anchors_scale.size() > 0) { + fprintf(pp, " -23306=%zu", anchors_scale.size()); + for (int i = 0; i < (int)anchors_scale.size(); i++) { + fprintf(pp, ",%e", anchors_scale[i]); + } + } } else { // TODO op specific param } diff --git a/configs/mmdet/_base_/base_static.py b/configs/mmdet/_base_/base_static.py index 3a33217971..9fe0d3433b 100644 --- a/configs/mmdet/_base_/base_static.py +++ b/configs/mmdet/_base_/base_static.py @@ -4,6 +4,7 @@ codebase_config = dict( type='mmdet', task='ObjectDetection', + model_type='end2end', post_processing=dict( score_threshold=0.05, confidence_threshold=0.005, # for YOLOv3 diff --git a/configs/mmdet/detection/single-stage_ncnn_dynamic.py b/configs/mmdet/detection/single-stage_ncnn_dynamic.py new file mode 100644 index 0000000000..9183da8608 --- /dev/null +++ b/configs/mmdet/detection/single-stage_ncnn_dynamic.py @@ -0,0 +1,4 @@ +_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ncnn.py'] + +codebase_config = dict(model_type='ncnn_end2end') +onnx_config = dict(output_names=['detection_output'], input_shape=None) diff --git a/mmdeploy/backend/ncnn/wrapper.py b/mmdeploy/backend/ncnn/wrapper.py index 239455ffe1..b4bbdacfcc 100644 --- a/mmdeploy/backend/ncnn/wrapper.py +++ b/mmdeploy/backend/ncnn/wrapper.py @@ -77,6 +77,7 @@ def forward(self, inputs: Dict[str, """ input_list = list(inputs.values()) batch_size = input_list[0].size(0) + assert batch_size == 1, 'Only batch_size=1 is supported!' for input_tensor in input_list[1:]: assert input_tensor.size( 0) == batch_size, 'All tensors should have same batch size' @@ -89,29 +90,23 @@ def forward(self, inputs: Dict[str, # create output dict outputs = dict([name, [None] * batch_size] for name in output_names) - # run inference - for batch_id in range(batch_size): - # create extractor - ex = self._net.create_extractor() - - # set inputs - for name, input_tensor in inputs.items(): - data = input_tensor[batch_id].contiguous() - data = data.detach().cpu().numpy() - input_mat = ncnn.Mat(data) - ex.input(name, input_mat) - - # get outputs - result = self.__ncnn_execute( - extractor=ex, output_names=output_names) - for name in output_names: - outputs[name][batch_id] = torch.from_numpy( - np.array(result[name])) - - # stack outputs together - for name, output_tensor in outputs.items(): - outputs[name] = torch.stack(output_tensor) + # create extractor + ex = self._net.create_extractor() + # set inputs + for name, input_tensor in inputs.items(): + data = input_tensor[0].contiguous().cpu().numpy() + input_mat = ncnn.Mat(data) + ex.input(name, input_mat) + # get outputs + result = self.__ncnn_execute(extractor=ex, output_names=output_names) + for name in output_names: + mat = result[name] + # deal with special case + if mat.empty(): + outputs[name] = None + continue + outputs[name] = torch.from_numpy(np.array(mat)).unsqueeze(0) return outputs @TimeCounter.count_time() diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 8a8259f18e..f4786c8e04 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -14,8 +14,8 @@ from mmdeploy.backend.base import get_backend_file_count from mmdeploy.codebase.base import BaseBackendModel from mmdeploy.codebase.mmdet import get_post_processing_params, multiclass_nms -from mmdeploy.utils import (Backend, get_backend, get_onnx_config, - get_partition_config, load_config) +from mmdeploy.utils import (Backend, get_backend, get_codebase_config, + get_onnx_config, get_partition_config, load_config) def __build_backend_model(partition_name: str, backend: Backend, @@ -259,7 +259,7 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ def show_result(self, img: np.ndarray, result: list, - win_name: str, + win_name: str = '', show: bool = True, score_thr: float = 0.3, out_file=None): @@ -516,6 +516,61 @@ class labels of shape [N, num_det]. return outputs +@__BACKEND_MODEL.register_module('ncnn_end2end') +class NCNNEnd2EndModel(End2EndModel): + """NCNNEnd2EndModel. + + End2end NCNN model inference class. Because it has DetectionOutput layer + and its output is different from original mmdet style of `dets`, `labels`. + + Args: + model_file (str): The path of input model file. + class_names (Sequence[str]): A list of string specifying class names. + model_cfg: (str | mmcv.Config): Input model config. + deploy_cfg: (str | mmcv.Config): Input deployment config. + device_id (int): An integer represents device index. + """ + + def __init__(self, backend: Backend, backend_files: Sequence[str], + device: str, class_names: Sequence[str], + model_cfg: Union[str, mmcv.Config], + deploy_cfg: Union[str, mmcv.Config], **kwargs): + assert backend == Backend.NCNN, f'only supported ncnn, but give \ + {backend.value}' + + super(NCNNEnd2EndModel, + self).__init__(backend, backend_files, device, class_names, + deploy_cfg, **kwargs) + # load cfg if necessary + model_cfg = load_config(model_cfg)[0] + self.model_cfg = model_cfg + + def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> List: + """Implement forward test. + + Args: + imgs (torch.Tensor): Input image(s) in [N x C x H x W] format. + + Returns: + list[np.ndarray]: dets of shape [N, num_det, 5] and + class labels of shape [N, num_det]. + """ + _, _, H, W = imgs.shape + outputs = self.wrapper({'input': imgs}) + for key, item in outputs.items(): + if item is None: + return [np.zeros((1, 0, 6))] + out = self.wrapper.output_to_list(outputs)[0] + labels = out[:, :, 0] - 1 + scales = torch.tensor([W, H, W, H]).reshape(1, 1, 4) + scores = out[:, :, 1:2] + boxes = out[:, :, 2:6] * scales + dets = torch.cat([boxes, scores], dim=2) + dets = dets.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + return [dets, labels] + + def get_classes_from_config(model_cfg: Union[str, mmcv.Config], **kwargs): """Get class name from config. @@ -566,11 +621,13 @@ def build_object_detection_model(model_files: Sequence[str], backend = get_backend(deploy_cfg) class_names = get_classes_from_config(model_cfg) - # Default Config is 'end2end' - partition_type = 'end2end' partition_config = get_partition_config(deploy_cfg) if partition_config is not None: partition_type = partition_config.get('type', None) + else: + codebase_config = get_codebase_config(deploy_cfg) + # Default Config is 'end2end' + partition_type = codebase_config.get('model_type', 'end2end') backend_detector = __BACKEND_MODEL.build( partition_type, diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py index e313f4f21c..bd7c043f2f 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import numpy as np import torch from mmdeploy.codebase.mmdet import (get_post_processing_params, @@ -183,136 +184,72 @@ def yolov3_head__get_bboxes__ncnn(ctx, if None, test_cfg would be used. Default: None. Returns: - If with_nms == True: - tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), - `dets` of shape [N, num_det, 5] and `labels` of shape - [N, num_det]. - Else: - tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores + Tensor: Detection_output of shape [num_boxes, 6], + each row is [label, score, x1, y1, x2, y2]. Note that + fore-ground class label in Yolov3DetectionOutput starts + from `1`. x1, y1, x2, y2 are normalized in range(0,1). """ num_levels = len(pred_maps) - pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)] - cfg = self.test_cfg if cfg is None else cfg - assert len(pred_maps_list) == self.num_levels - - device = pred_maps_list[0].device - batch_size = pred_maps_list[0].shape[0] - - featmap_sizes = [ - pred_maps_list[i].shape[-2:] for i in range(self.num_levels) - ] - multi_lvl_anchors = self.anchor_generator.grid_anchors( - featmap_sizes, device) - pre_topk = cfg.get('nms_pre', -1) - multi_lvl_bboxes = [] - multi_lvl_cls_scores = [] - multi_lvl_conf_scores = [] - for i in range(self.num_levels): - # get some key info for current scale - pred_map = pred_maps_list[i] - stride = self.featmap_strides[i] - # (b,h, w, num_anchors*num_attrib) -> - # (b,h*w*num_anchors, num_attrib) - pred_map = pred_map.permute(0, 2, 3, - 1).reshape(batch_size, -1, self.num_attrib) - # Inplace operation like - # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])``` - # would create constant tensor when exporting to onnx - pred_map_conf = torch.sigmoid(pred_map[..., :2]) - pred_map_rest = pred_map[..., 2:] - # dim must be written as 2, but not -1, because ncnn implicit batch - # mechanism. - pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=2) - pred_map_boxes = pred_map[..., :4] - multi_lvl_anchor = multi_lvl_anchors[i] - # use static anchor if input shape is static - multi_lvl_anchor = multi_lvl_anchor.unsqueeze(0).expand_as( - pred_map_boxes).data - - bbox_pred = self.bbox_coder.decode(multi_lvl_anchor, pred_map_boxes, - stride) - # conf and cls - conf_pred = torch.sigmoid(pred_map[..., 4]) - cls_pred = torch.sigmoid(pred_map[..., 5:]).view( - batch_size, -1, self.num_classes) # Cls pred one-hot. - - if pre_topk > 0: - _, topk_inds = conf_pred.topk(pre_topk) - topk_inds = topk_inds.view(-1) - bbox_pred = bbox_pred[:, topk_inds, :] - cls_pred = cls_pred[:, topk_inds, :] - conf_pred = conf_pred[:, topk_inds] - - # Save the result of current scale - multi_lvl_bboxes.append(bbox_pred) - multi_lvl_cls_scores.append(cls_pred) - multi_lvl_conf_scores.append(conf_pred) - - # Merge the results of different scales together - batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1) - batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1) - batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1) - post_params = get_post_processing_params(ctx.cfg) - score_threshold = cfg.get('score_thr', post_params.score_threshold) confidence_threshold = cfg.get('conf_thr', post_params.confidence_threshold) - - # helper function for creating Threshold op - def _create_threshold(x, thresh): - - class ThresholdOp(torch.autograd.Function): - """Create Threshold op.""" + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + anchor_biases = np.array( + self.anchor_generator.base_sizes).reshape(-1).tolist() + num_box = len(self.anchor_generator.base_sizes[0]) + bias_masks = list(range(num_levels * num_box)) + + def _create_yolov3_detection_output(): + """Help create Yolov3DetectionOutput op in ONNX.""" + + class Yolov3DetectionOutputOp(torch.autograd.Function): + """Create Yolov3DetectionOutput op. + + Args: + *inputs (Tensor): Multiple predicted feature maps. + num_class (int): Number of classes. + num_box (int): Number of box per grid. + confidence_threshold (float): Threshold of object + score. + nms_threshold (float): IoU threshold for NMS. + biases (List[float]: Base sizes to compute anchors + for each FPN. + mask (List[float]): Used to select base sizes in + biases. + anchors_scale (List[float]): Down-sampling scales of + each FPN layer, e.g.: [32, 16]. + """ @staticmethod - def forward(ctx, x, threshold): - return x > threshold + def forward(ctx, *args): + # create dummpy output of shape [num_boxes, 6], + # each row is [label, score, x1, y1, x2, y2] + output = torch.rand(100, 6) + return output @staticmethod - def symbolic(g, x, threshold): + def symbolic(g, *args): + anchors_scale = args[-1] + inputs = args[:len(anchors_scale)] + assert len(args) == (len(anchors_scale) + 7) return g.op( - 'mmdeploy::Threshold', x, threshold_f=threshold, outputs=1) - - return ThresholdOp.apply(x, thresh) - - # follow original pipeline of YOLOv3 - if confidence_threshold > 0: - mask = _create_threshold(batch_mlvl_conf_scores, - confidence_threshold).float() - batch_mlvl_conf_scores *= mask - if score_threshold > 0: - mask = _create_threshold(batch_mlvl_scores, score_threshold).float() - batch_mlvl_scores *= mask - - # NCNN broadcast needs the same in channel dimension. - _batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).unsqueeze(3) - _batch_mlvl_scores = batch_mlvl_scores.unsqueeze(3) - batch_mlvl_scores = (_batch_mlvl_scores * _batch_mlvl_conf_scores).reshape( - batch_mlvl_scores.shape) - # Although batch_mlvl_bboxes already has the shape of - # (batch_size, -1, 4), ncnn implicit batch mechanism in the model and - # ncnn channel alignment would result in a shape of - # (batch_size, -1, 4, 1). So, we need a reshape op to ensure the - # batch_mlvl_bboxes shape is right. - batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, -1, 4) - - if with_nms: - max_output_boxes_per_class = post_params.max_output_boxes_per_class - iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) - pre_top_k = post_params.pre_top_k - keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) - # keep aligned with original pipeline, improve - # mAP by 1% for YOLOv3 in ONNX - score_threshold = 0 - return multiclass_nms( - batch_mlvl_bboxes, - batch_mlvl_scores, - max_output_boxes_per_class, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pre_top_k=pre_top_k, - keep_top_k=keep_top_k) - else: - return batch_mlvl_bboxes, batch_mlvl_scores + 'mmdeploy::Yolov3DetectionOutput', + *inputs, + num_class_i=args[-7], + num_box_i=args[-6], + confidence_threshold_f=args[-5], + nms_threshold_f=args[-4], + biases_f=args[-3], + mask_f=args[-2], + anchors_scale_f=anchors_scale, + outputs=1) + + return Yolov3DetectionOutputOp.apply(*pred_maps, self.num_classes, + num_box, confidence_threshold, + iou_threshold, anchor_biases, + bias_masks, self.featmap_strides) + + output = _create_yolov3_detection_output() + return output diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index 7c2357a752..c2f170f053 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .config_utils import (cfg_apply_marks, get_backend, get_calib_config, - get_calib_filename, get_codebase, get_common_config, + get_calib_filename, get_codebase, + get_codebase_config, get_common_config, get_input_shape, get_model_inputs, get_onnx_config, get_partition_config, get_task_type, is_dynamic_batch, is_dynamic_shape, load_config) @@ -13,5 +14,5 @@ 'get_onnx_config', 'get_partition_config', 'get_calib_config', 'get_calib_filename', 'get_common_config', 'get_model_inputs', 'cfg_apply_marks', 'get_input_shape', 'parse_device_id', - 'parse_cuda_device_id' + 'parse_cuda_device_id', 'get_codebase_config' ] diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index 1192da3221..1ddd4f124a 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -30,7 +30,8 @@ def _load_config(cfg): return configs -def get_task_type(deploy_cfg: Union[str, mmcv.Config], default=None) -> Task: +def get_task_type(deploy_cfg: Union[str, mmcv.Config], + default: str = None) -> Task: """Get the task type of the algorithm. Args: @@ -42,15 +43,30 @@ def get_task_type(deploy_cfg: Union[str, mmcv.Config], default=None) -> Task: Task : An enumeration denotes the task type. """ - deploy_cfg = load_config(deploy_cfg)[0] + codebase_config = get_codebase_config(deploy_cfg) try: - task = deploy_cfg['codebase_config']['task'] + task = codebase_config['task'] except KeyError: return default + task = Task.get(task, default) return task +def get_codebase_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict: + """Get the codebase_config from the config. + + Args: + deploy_cfg (str | mmcv.Config): The path or content of config. + + Returns: + Dict : codebase config dict. + """ + deploy_cfg = load_config(deploy_cfg)[0] + codebase_config = deploy_cfg.get('codebase_config', {}) + return codebase_config + + def get_codebase(deploy_cfg: Union[str, mmcv.Config], default: Optional[str] = None) -> Codebase: """Get the codebase from the config. @@ -64,11 +80,12 @@ def get_codebase(deploy_cfg: Union[str, mmcv.Config], Codebase : An enumeration denotes the codebase type. """ - deploy_cfg = load_config(deploy_cfg)[0] + codebase_config = get_codebase_config(deploy_cfg) try: - codebase = deploy_cfg['codebase_config']['type'] + codebase = codebase_config['type'] except KeyError: return default + codebase = Codebase.get(codebase, default) return codebase diff --git a/tests/test_codebase/test_mmdet/test_object_detection_model.py b/tests/test_codebase/test_mmdet/test_object_detection_model.py index 3757a8d3d5..f14200f8f9 100644 --- a/tests/test_codebase/test_mmdet/test_object_detection_model.py +++ b/tests/test_codebase/test_mmdet/test_object_detection_model.py @@ -7,6 +7,7 @@ import pytest import torch +import mmdeploy.backend.ncnn as ncnn_apis import mmdeploy.backend.onnxruntime as ort_apis from mmdeploy.codebase.mmdet.deploy.object_detection_model import End2EndModel from mmdeploy.utils import Backend @@ -53,8 +54,8 @@ def setup_class(cls): 'output_names': ['dets', 'labels'] }}) - from mmdeploy.codebase.mmdet.deploy.object_detection_model \ - import End2EndModel + from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ + End2EndModel cls.end2end_model = End2EndModel(Backend.ONNXRUNTIME, [''], 'cpu', ['' for i in range(80)], deploy_cfg) @@ -114,8 +115,8 @@ def setup_class(cls): } }) - from mmdeploy.codebase.mmdet.deploy.object_detection_model \ - import End2EndModel + from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ + End2EndModel cls.end2end_model = End2EndModel(Backend.ONNXRUNTIME, [''], 'cpu', ['' for i in range(80)], deploy_cfg) @@ -179,8 +180,8 @@ def setup_class(cls): deploy_cfg = mmcv.Config( dict(codebase_config=dict(post_processing=post_processing))) - from mmdeploy.codebase.mmdet.deploy.object_detection_model \ - import PartitionSingleStageModel + from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ + PartitionSingleStageModel cls.model = PartitionSingleStageModel( Backend.ONNXRUNTIME, [''], 'cpu', ['' for i in range(80)], @@ -293,8 +294,8 @@ def setup_class(cls): outputs=outputs, model_cfg=model_cfg, deploy_cfg=deploy_cfg) # replace original function in PartitionTwoStageModel - from mmdeploy.codebase.mmdet.deploy.object_detection_model \ - import PartitionTwoStageModel + from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ + PartitionTwoStageModel cls.model = PartitionTwoStageModel( Backend.ONNXRUNTIME, ['', ''], @@ -399,6 +400,7 @@ def partition1_postprocess(self, *args, **kwargs): @pytest.mark.parametrize('cfg', [data_cfg1, data_cfg2, data_cfg3, data_cfg4]) def test_get_classes_from_cfg(cfg): from mmdet.datasets import DATASETS + from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ get_classes_from_config @@ -436,3 +438,42 @@ def test_build_object_detection_model(partition_type): detector = build_object_detection_model([''], model_cfg, deploy_cfg, 'cpu') assert isinstance(detector, End2EndModel) + + +@backend_checker(Backend.NCNN) +class TestNCNNEnd2EndModel: + + @classmethod + def setup_class(cls): + # force add backend wrapper regardless of plugins + from mmdeploy.backend.ncnn import NCNNWrapper + ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper}) + + # simplify backend inference + cls.wrapper = SwitchBackendWrapper(NCNNWrapper) + cls.outputs = { + 'output': torch.rand(1, 10, 6), + } + cls.wrapper.set(outputs=cls.outputs) + deploy_cfg = mmcv.Config({'onnx_config': {'output_names': ['output']}}) + model_cfg = mmcv.Config({}) + + from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ + NCNNEnd2EndModel + cls.ncnn_end2end_model = NCNNEnd2EndModel(Backend.NCNN, ['', ''], + 'cpu', + ['' for i in range(80)], + model_cfg, deploy_cfg) + + @classmethod + def teardown_class(cls): + cls.wrapper.recover() + + @pytest.mark.parametrize('num_det', [10, 0]) + def test_forward_test(self, num_det): + self.outputs = { + 'output': torch.rand(1, num_det, 6), + } + imgs = torch.rand(1, 3, 64, 64) + results = self.ncnn_end2end_model.forward_test(imgs) + assert_det_results(results, 'NCNNEnd2EndModel') diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index be7b3e651e..7210e21b32 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -71,6 +71,16 @@ def test_get_task_type(self): assert util.get_task_type(correct_deploy_path) == Task.SUPER_RESOLUTION +class TestGetCodebaseConfig: + + def test_get_codebase_config_empty(self): + assert util.get_codebase_config(mmcv.Config(dict())) == {} + + def test_get_codebase_config(self): + codebase_config = util.get_codebase_config(correct_deploy_path) + assert isinstance(codebase_config, dict) and len(codebase_config) > 1 + + class TestGetCodebase: def test_get_codebase_none(self):