Skip to content

Commit

Permalink
[Feature] end2end yolov3 with ncnn (open-mmlab#248)
Browse files Browse the repository at this point in the history
* support yolov3 ncnn with Yolov3DetectionOutput

* update nms

* fix contiguous in ncnn wrapper

* remove padding to detectionoutput

* format cpp

* Revert "format cpp"

This reverts commit 54050b19cd80d2f8cd851d82a755fd2c8d6c779d.

* fix zero detection

* fix yapf

* onnx2ncnn.cpp

* fix ut

* fix isort

* fix clang-format

* format cpp

* resolve comments

* resolve comments

* fix ut of ncnnend2endmodel

* fix yapf

* fix return list;

Co-authored-by: hanrui1sensetime <[email protected]>
  • Loading branch information
RunningLeon and hanrui1sensetime authored Dec 6, 2021
1 parent 2f6f6f8 commit 9e82851
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 163 deletions.
34 changes: 34 additions & 0 deletions backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3551,6 +3551,8 @@ int main(int argc, char** argv) {
}
} else if (op == "Where") {
fprintf(pp, "%-16s", "Where");
} else if (op == "Yolov3DetectionOutput") {
fprintf(pp, "%-16s", "Yolov3DetectionOutput");
} else {
// TODO
fprintf(stderr, "%s not supported yet!\n", op.c_str());
Expand Down Expand Up @@ -5382,6 +5384,38 @@ int main(int argc, char** argv) {
fprintf(pp, ",%d", axes[i]);
}
}
} else if (op == "Yolov3DetectionOutput") {
int num_class = get_node_attr_i(node, "num_class");
int num_box = get_node_attr_i(node, "num_box");
float confidence_threshold =
get_node_attr_f(node, "confidence_threshold");
float nms_threshold = get_node_attr_f(node, "nms_threshold");
fprintf(pp, " 0=%d", num_class);
fprintf(pp, " 1=%d", num_box);
fprintf(pp, " 2=%e", confidence_threshold);
fprintf(pp, " 3=%e", nms_threshold);
std::vector<float> biases = get_node_attr_af(node, "biases");
if (biases.size() > 0) {
fprintf(pp, " -23304=%zu", biases.size());
for (int i = 0; i < (int)biases.size(); i++) {
fprintf(pp, ",%e", biases[i]);
}
}
std::vector<float> mask = get_node_attr_af(node, "mask");
if (mask.size() > 0) {
fprintf(pp, " -23305=%zu", mask.size());
for (int i = 0; i < (int)mask.size(); i++) {
fprintf(pp, ",%e", mask[i]);
}
}
std::vector<float> anchors_scale =
get_node_attr_af(node, "anchors_scale");
if (anchors_scale.size() > 0) {
fprintf(pp, " -23306=%zu", anchors_scale.size());
for (int i = 0; i < (int)anchors_scale.size(); i++) {
fprintf(pp, ",%e", anchors_scale[i]);
}
}
} else {
// TODO op specific param
}
Expand Down
1 change: 1 addition & 0 deletions configs/mmdet/_base_/base_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
codebase_config = dict(
type='mmdet',
task='ObjectDetection',
model_type='end2end',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005, # for YOLOv3
Expand Down
4 changes: 4 additions & 0 deletions configs/mmdet/detection/single-stage_ncnn_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/ncnn.py']

codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=None)
39 changes: 17 additions & 22 deletions mmdeploy/backend/ncnn/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def forward(self, inputs: Dict[str,
"""
input_list = list(inputs.values())
batch_size = input_list[0].size(0)
assert batch_size == 1, 'Only batch_size=1 is supported!'
for input_tensor in input_list[1:]:
assert input_tensor.size(
0) == batch_size, 'All tensors should have same batch size'
Expand All @@ -89,29 +90,23 @@ def forward(self, inputs: Dict[str,
# create output dict
outputs = dict([name, [None] * batch_size] for name in output_names)

# run inference
for batch_id in range(batch_size):
# create extractor
ex = self._net.create_extractor()

# set inputs
for name, input_tensor in inputs.items():
data = input_tensor[batch_id].contiguous()
data = data.detach().cpu().numpy()
input_mat = ncnn.Mat(data)
ex.input(name, input_mat)

# get outputs
result = self.__ncnn_execute(
extractor=ex, output_names=output_names)
for name in output_names:
outputs[name][batch_id] = torch.from_numpy(
np.array(result[name]))

# stack outputs together
for name, output_tensor in outputs.items():
outputs[name] = torch.stack(output_tensor)
# create extractor
ex = self._net.create_extractor()
# set inputs
for name, input_tensor in inputs.items():
data = input_tensor[0].contiguous().cpu().numpy()
input_mat = ncnn.Mat(data)
ex.input(name, input_mat)

# get outputs
result = self.__ncnn_execute(extractor=ex, output_names=output_names)
for name in output_names:
mat = result[name]
# deal with special case
if mat.empty():
outputs[name] = None
continue
outputs[name] = torch.from_numpy(np.array(mat)).unsqueeze(0)
return outputs

@TimeCounter.count_time()
Expand Down
67 changes: 62 additions & 5 deletions mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from mmdeploy.backend.base import get_backend_file_count
from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.codebase.mmdet import get_post_processing_params, multiclass_nms
from mmdeploy.utils import (Backend, get_backend, get_onnx_config,
get_partition_config, load_config)
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
get_onnx_config, get_partition_config, load_config)


def __build_backend_model(partition_name: str, backend: Backend,
Expand Down Expand Up @@ -259,7 +259,7 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
def show_result(self,
img: np.ndarray,
result: list,
win_name: str,
win_name: str = '',
show: bool = True,
score_thr: float = 0.3,
out_file=None):
Expand Down Expand Up @@ -516,6 +516,61 @@ class labels of shape [N, num_det].
return outputs


@__BACKEND_MODEL.register_module('ncnn_end2end')
class NCNNEnd2EndModel(End2EndModel):
"""NCNNEnd2EndModel.
End2end NCNN model inference class. Because it has DetectionOutput layer
and its output is different from original mmdet style of `dets`, `labels`.
Args:
model_file (str): The path of input model file.
class_names (Sequence[str]): A list of string specifying class names.
model_cfg: (str | mmcv.Config): Input model config.
deploy_cfg: (str | mmcv.Config): Input deployment config.
device_id (int): An integer represents device index.
"""

def __init__(self, backend: Backend, backend_files: Sequence[str],
device: str, class_names: Sequence[str],
model_cfg: Union[str, mmcv.Config],
deploy_cfg: Union[str, mmcv.Config], **kwargs):
assert backend == Backend.NCNN, f'only supported ncnn, but give \
{backend.value}'

super(NCNNEnd2EndModel,
self).__init__(backend, backend_files, device, class_names,
deploy_cfg, **kwargs)
# load cfg if necessary
model_cfg = load_config(model_cfg)[0]
self.model_cfg = model_cfg

def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> List:
"""Implement forward test.
Args:
imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
Returns:
list[np.ndarray]: dets of shape [N, num_det, 5] and
class labels of shape [N, num_det].
"""
_, _, H, W = imgs.shape
outputs = self.wrapper({'input': imgs})
for key, item in outputs.items():
if item is None:
return [np.zeros((1, 0, 6))]
out = self.wrapper.output_to_list(outputs)[0]
labels = out[:, :, 0] - 1
scales = torch.tensor([W, H, W, H]).reshape(1, 1, 4)
scores = out[:, :, 1:2]
boxes = out[:, :, 2:6] * scales
dets = torch.cat([boxes, scores], dim=2)
dets = dets.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
return [dets, labels]


def get_classes_from_config(model_cfg: Union[str, mmcv.Config], **kwargs):
"""Get class name from config.
Expand Down Expand Up @@ -566,11 +621,13 @@ def build_object_detection_model(model_files: Sequence[str],
backend = get_backend(deploy_cfg)
class_names = get_classes_from_config(model_cfg)

# Default Config is 'end2end'
partition_type = 'end2end'
partition_config = get_partition_config(deploy_cfg)
if partition_config is not None:
partition_type = partition_config.get('type', None)
else:
codebase_config = get_codebase_config(deploy_cfg)
# Default Config is 'end2end'
partition_type = codebase_config.get('model_type', 'end2end')

backend_detector = __BACKEND_MODEL.build(
partition_type,
Expand Down
Loading

0 comments on commit 9e82851

Please sign in to comment.