diff --git a/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py b/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py new file mode 100644 index 0000000000..20e10a2562 --- /dev/null +++ b/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py @@ -0,0 +1,12 @@ +_base_ = ['./detection_onnxruntime_static.py'] + +onnx_config = dict(input_shape=[608, 608]) +partition_config = dict( + type='yolov3_partition', + apply_marks=True, + partition_cfg=[ + dict( + save_file='yolov3.onnx', + start=['detector_forward:input'], + end=['yolo_head:input']) + ]) diff --git a/docs/en/01-how-to-build/build_from_docker.md b/docs/en/01-how-to-build/build_from_docker.md index 0157c92e04..4526d8205b 100644 --- a/docs/en/01-how-to-build/build_from_docker.md +++ b/docs/en/01-how-to-build/build_from_docker.md @@ -1,8 +1,8 @@ -## Use Docker Image +# Use Docker Image We provide two dockerfiles for CPU and GPU respectively. For CPU users, we install MMDeploy with ONNXRuntime, ncnn and OpenVINO backends. For GPU users, we install MMDeploy with TensorRT backend. Besides, users can install mmdeploy with different versions when building the docker image. -### Build docker image +## Build docker image For CPU users, we can build the docker image with the latest MMDeploy through: @@ -37,7 +37,7 @@ cd mmdeploy docker build docker/CPU/ -t mmdeploy:inside --build-arg USE_SRC_INSIDE=true ``` -### Run docker container +## Run docker container After building the docker image succeed, we can use `docker run` to launch the docker service. GPU docker image for example: @@ -45,7 +45,7 @@ After building the docker image succeed, we can use `docker run` to launch the d docker run --gpus all -it -p 8080:8081 mmdeploy:master-gpu ``` -### FAQs +## FAQs 1. CUDA error: the provided PTX was compiled with an unsupported toolchain: diff --git a/docs/en/02-how-to-run/write_config.md b/docs/en/02-how-to-run/write_config.md index 8eef707d29..92297d4de3 100644 --- a/docs/en/02-how-to-run/write_config.md +++ b/docs/en/02-how-to-run/write_config.md @@ -1,4 +1,4 @@ -## How to write config +# How to write config This tutorial describes how to write a config for model conversion and deployment. A deployment config includes `onnx config`, `codebase config`, `backend config`. @@ -24,11 +24,11 @@ This tutorial describes how to write a config for model conversion and deploymen -### 1. How to write onnx config +## 1. How to write onnx config Onnx config to describe how to export a model from pytorch to onnx. -#### Description of onnx config arguments +### Description of onnx config arguments - `type`: Type of config dict. Default is `onnx`. - `export_params`: If specified, all parameters will be exported. Set this to False if you want to export an untrained model. @@ -39,7 +39,7 @@ Onnx config to describe how to export a model from pytorch to onnx. - `output_names`: Names to assign to the output nodes of the graph. - `input_shape`: The height and width of input tensor to the model. -##### Example +### Example ```python onnx_config = dict( @@ -53,13 +53,13 @@ onnx_config = dict( input_shape=None) ``` -#### If you need to use dynamic axes +### If you need to use dynamic axes If the dynamic shape of inputs and outputs is required, you need to add dynamic_axes dict in onnx config. - `dynamic_axes`: Describe the dimensional information about input and output. -##### Example +#### Example ```python dynamic_axes={ @@ -79,28 +79,28 @@ If the dynamic shape of inputs and outputs is required, you need to add dynamic_ } ``` -### 2. How to write codebase config +## 2. How to write codebase config Codebase config part contains information like codebase type and task type. -#### Description of codebase config arguments +### Description of codebase config arguments - `type`: Model's codebase, including `mmcls`, `mmdet`, `mmseg`, `mmocr`, `mmedit`. - `task`: Model's task type, referring to [List of tasks in all codebases](#list-of-tasks-in-all-codebases). -##### Example +#### Example ```python codebase_config = dict(type='mmcls', task='Classification') ``` -### 3. How to write backend config +## 3. How to write backend config The backend config is mainly used to specify the backend on which model runs and provide the information needed when the model runs on the backend , referring to [ONNX Runtime](../05-supported-backends/onnxruntime.md), [TensorRT](../05-supported-backends/tensorrt.md), [ncnn](../05-supported-backends/ncnn.md), [PPLNN](../05-supported-backends/pplnn.md). - `type`: Model's backend, including `onnxruntime`, `ncnn`, `pplnn`, `tensorrt`, `openvino`. -#### Example +### Example ```python backend_config = dict( @@ -117,7 +117,7 @@ backend_config = dict( ]) ``` -### 4. A complete example of mmcls on TensorRT +## 4. A complete example of mmcls on TensorRT Here we provide a complete deployment config from mmcls on TensorRT. @@ -159,7 +159,7 @@ onnx_config = dict( input_shape=[224, 224]) ``` -### 5. The name rules of our deployment config +## 5. The name rules of our deployment config There is a specific naming convention for the filename of deployment config files. @@ -171,20 +171,12 @@ There is a specific naming convention for the filename of deployment config file - `backend name`: Backend's name. Note if you use the quantization function, you need to indicate the quantization type. Just like `tensorrt-int8`. - `dynamic or static`: Dynamic or static export. Note if the backend needs explicit shape information, you need to add a description of input size with `height x width` format. Just like `dynamic-512x1024-2048x2048`, it means that the min input shape is `512x1024` and the max input shape is `2048x2048`. -#### Example +### Example ```bash detection_tensorrt-int8_dynamic-320x320-1344x1344.py ``` -### 6. How to write model config +## 6. How to write model config According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/master/docs/tutorials/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs_zh-CN/tutorials/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs_zh-CN/tutorials/config.md), [MMOCR](https://github.com/open-mmlab/mmocr/tree/main/configs), [MMEditing](https://github.com/open-mmlab/mmediting/blob/master/docs_zh-CN/config.md). - -### 7. Reminder - -None - -### 8. FAQs - -None diff --git a/docs/en/06-developer-guide/add_test_units_for_backend_ops.md b/docs/en/06-developer-guide/add_test_units_for_backend_ops.md index 5f9fa300b3..8c517857b2 100644 --- a/docs/en/06-developer-guide/add_test_units_for_backend_ops.md +++ b/docs/en/06-developer-guide/add_test_units_for_backend_ops.md @@ -1,16 +1,16 @@ -## How to add test units for backend ops +# How to add test units for backend ops This tutorial introduces how to add unit test for backend ops. When you add a custom op under `backend_ops`, you need to add the corresponding test unit. Test units of ops are included in `tests/test_ops/test_ops.py`. -### Prerequisite +## Prerequisite - `Compile new ops`: After adding a new custom op, needs to recompile the relevant backend, referring to [build.md](../01-how-to-build/build_from_source.md). -### 1. Add the test program test_XXXX() +## 1. Add the test program test_XXXX() You can put unit test for ops in `tests/test_ops/`. Usually, the following program template can be used for your custom op. -#### example of ops unit test +### example of ops unit test ```python @pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT]) # 1.1 backend test class @@ -49,26 +49,26 @@ def test_roi_align(backend, save_dir=save_dir) ``` -#### 1.1 backend test class +### 1.1 backend test class We provide some functions and classes for difference backends, such as `TestOnnxRTExporter`, `TestTensorRTExporter`, `TestNCNNExporter`. -#### 1.2 set parameters of op +### 1.2 set parameters of op Set some parameters of op, such as ’pool_h‘, ’pool_w‘, ’spatial_scale‘, ’sampling_ratio‘ in roi_align. You can set multiple parameters to test op. -#### 1.3 op input data initialization +### 1.3 op input data initialization Initialization required input data. -#### 1.4 initialize op model to be tested +### 1.4 initialize op model to be tested The model containing custom op usually has two forms. - `torch model`: Torch model with custom operators. Python code related to op is required, refer to `roi_align` unit test. - `onnx model`: Onnx model with custom operators. Need to call onnx api to build, refer to `multi_level_roi_align` unit test. -#### 1.5 call the backend test class interface +### 1.5 call the backend test class interface Call the backend test class `run_and_validate` to run and verify the result output by the op on the backend. @@ -86,7 +86,7 @@ Call the backend test class `run_and_validate` to run and verify the result outp save_dir=None): ``` -##### Parameter Description +#### Parameter Description - `model`: Input model to be tested and it can be torch model or any other backend model. - `input_list`: List of test data, which is mapped to the order of input_names. @@ -99,7 +99,7 @@ Call the backend test class `run_and_validate` to run and verify the result outp - `expected_result`: Expected ground truth values for verification. - `save_dir`: The folder used to save the output files. -### 2. Test Methods +## 2. Test Methods Use pytest to call the test function to test ops. diff --git a/docs/en/06-developer-guide/partition_model.md b/docs/en/06-developer-guide/partition_model.md new file mode 100644 index 0000000000..eadf63766d --- /dev/null +++ b/docs/en/06-developer-guide/partition_model.md @@ -0,0 +1,89 @@ +# How to get partitioned ONNX models + +MMDeploy supports exporting PyTorch models to partitioned onnx models. With this feature, users can define their partition policy and get partitioned onnx models at ease. In this tutorial, we will briefly introduce how to support partition a model step by step. In the example, we would break YOLOV3 model into two parts and extract the first part without the post-processing (such as anchor generating and NMS) in the onnx model. + +## Step 1: Mark inputs/outpupts + +To support the model partition, we need to add `Mark` nodes in the ONNX model. This could be done with mmdeploy's `@mark` decorator. Note that to make the `mark` work, the marking operation should be included in a rewriting function. + +At first, we would mark the model input, which could be done by marking the input tensor `img` in the `forward` method of `BaseDetector` class, which is the parent class of all detector classes. Thus we name this marking point as `detector_forward` and mark the inputs as `input`. Since there could be three outputs for detectors such as `Mask RCNN`, the outputs are marked as `dets`, `labels`, and `masks`. The following code shows the idea of adding mark functions and calling the mark functions in the rewrite. For source code, you could refer to [mmdeploy/codebase/mmdet/models/detectors/base.py](https://github.com/open-mmlab/mmdeploy/blob/86a50e343a3a45d7bc2ba3256100accc4973e71d/mmdeploy/codebase/mmdet/models/detectors/base.py) + +```python +from mmdeploy.core import FUNCTION_REWRITER, mark + +@mark( + 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) +def __forward_impl(ctx, self, img, img_metas=None, **kwargs): + ... + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.base.BaseDetector.forward') +def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): + ... + # call the mark function + return __forward_impl(...) +``` + +Then, we have to mark the output feature of `YOLOV3Head`, which is the input argument `pred_maps` in `get_bboxes` method of `YOLOV3Head` class. We could add a internal function to only mark the `pred_maps` inside [`yolov3_head__get_bboxes`](https://github.com/open-mmlab/mmdeploy/blob/86a50e343a3a45d7bc2ba3256100accc4973e71d/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py#L14) function as following. + +```python +from mmdeploy.core import FUNCTION_REWRITER, mark + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes') +def yolov3_head__get_bboxes(ctx, + self, + pred_maps, + img_metas, + cfg=None, + rescale=False, + with_nms=True): + # mark pred_maps + @mark('yolo_head', inputs=['pred_maps']) + def __mark_pred_maps(pred_maps): + return pred_maps + pred_maps = __mark_pred_maps(pred_maps) + ... +``` + +Note that `pred_maps` is a list of `Tensor` and it has three elements. Thus, three `Mark` nodes with op name as `pred_maps.0`, `pred_maps.1`, `pred_maps.2` would be added in the onnx model. + +## Step 2: Add partition config + +After marking necessary nodes that would be used to split the model, we could add a deployment config file `configs/mmdet/detection/yolov3_partition_onnxruntime_static.py`. If you are not familiar with how to write config, you could check [write_config.md](../02-how-to-run/write_config.md). + +In the config file, we need to add `partition_config`. The key part is `partition_cfg`, which contains elements of dict that designates the start nodes and end nodes of each model segments. Since we only want to keep `YOLOV3` without post-processing, we could set the `start` as `['detector_forward:input']`, and `end` as `['yolo_head:input']`. Note that `start` and `end` can have multiple marks. + +```python +_base_ = ['./detection_onnxruntime_static.py'] + +onnx_config = dict(input_shape=[608, 608]) +partition_config = dict( + type='yolov3_partition', # the partition policy name + apply_marks=True, # should always be set to True + partition_cfg=[ + dict( + save_file='yolov3.onnx', # filename to save the partitioned onnx model + start=['detector_forward:input'], # [mark_name:input/output, ...] + end=['yolo_head:input']) # [mark_name:input/output, ...] + ]) + +``` + +## Step 3: Get partitioned onnx models + +Once we have marks of nodes and the deployment config with `parition_config` being set properly, we could use the [tool](../useful_tools.md) `torch2onnx` to export the model to onnx and get the partition onnx files. + +```shell +python tools/torch2onnx.py \ +configs/mmdet/detection/yolov3_partition_onnxruntime_static.py \ +../mmdetection/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py \ +https://download.openmmlab.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-608_273e_coco/yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8.pth \ +../mmdetection/demo/demo.jpg \ +--work-dir ./work-dirs/mmdet/yolov3/ort/partition +``` + +After run the script above, we would have the partitioned onnx file `yolov3.onnx` in the `work-dir`. You can use the visualization tool [netron](https://netron.app/) to check the model structure. + +With the partitioned onnx file, you could refer to [useful_tools.md](../useful_tools.md) to do the following procedures such as `onnx2ncnn`, `onnx2tensorrt`. diff --git a/docs/en/06-developer-guide/support_new_backend.md b/docs/en/06-developer-guide/support_new_backend.md index d5ccce4e4f..61dafa3870 100644 --- a/docs/en/06-developer-guide/support_new_backend.md +++ b/docs/en/06-developer-guide/support_new_backend.md @@ -1,8 +1,8 @@ -## How to support new backends +# How to support new backends MMDeploy supports a number of backend engines. We welcome the contribution of new backends. In this tutorial, we will introduce the general procedures to support a new backend in MMDeploy. -### Prerequisites +## Prerequisites Before contributing the codes, there are some requirements for the new backend that need to be checked: @@ -10,7 +10,7 @@ Before contributing the codes, there are some requirements for the new backend t - If the backend requires model files or weight files other than a ".onnx" file, a conversion tool that converts the ".onnx" file to model files and weight files is required. The tool can be a Python API, a script, or an executable program. - It is highly recommended that the backend provides a Python interface to load the backend files and inference for validation. -### Support backend conversion +## Support backend conversion The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" file directly, or converts the ".onnx" to its own format using the conversion tool. In this section, we will introduce the steps to support backend conversion. @@ -155,7 +155,7 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" fi 7. Add docstring and unit tests for new code :). -### Support backend inference +## Support backend inference Although the backend engines are usually implemented in C/C++, it is convenient for testing and debugging if the backend provides Python inference interface. We encourage the contributors to support backend inference in the Python interface of MMDeploy. In this section we will introduce the steps to support backend inference. @@ -230,7 +230,7 @@ Although the backend engines are usually implemented in C/C++, it is convenient 5. Add docstring and unit tests for new code :). -### Support new backends using MMDeploy as a third party +## Support new backends using MMDeploy as a third party Previous parts show how to add a new backend in MMDeploy, which requires changing its source codes. However, if we treat MMDeploy as a third party, the methods above are no longer efficient. To this end, adding a new backend requires us pre-install another package named `aenum`. We can install it directly through `pip install aenum`. diff --git a/docs/en/06-developer-guide/support_new_model.md b/docs/en/06-developer-guide/support_new_model.md index 7406808074..ae456a45b7 100644 --- a/docs/en/06-developer-guide/support_new_model.md +++ b/docs/en/06-developer-guide/support_new_model.md @@ -1,8 +1,8 @@ -## How to support new models +# How to support new models We provide several tools to support model conversion. -### Function Rewriter +## Function Rewriter The PyTorch neural network is written in python that eases the development of the algorithm. But the use of Python control flow and third-party libraries make it difficult to export the network to an intermediate representation. We provide a 'monkey patch' tool to rewrite the unsupported function to another one that can be exported. Here is an example: @@ -26,7 +26,7 @@ It is easy to use the function rewriter. Just add a decorator with arguments: The arguments are the same as the original function, except a context `ctx` as the first argument. The context provides some useful information such as the deployment config `ctx.cfg` and the original function (which has been overridden) `ctx.origin_func`. -### Module Rewriter +## Module Rewriter If you want to replace a whole module with another one, we have another rewriter as follows: @@ -66,7 +66,7 @@ Just like function rewriter, add a decorator with arguments: All instances of the module in the network will be replaced with instances of this new class. The original module and the deployment config will be passed as the first two arguments. -### Custom Symbolic +## Custom Symbolic The mappings between PyTorch and ONNX are defined in PyTorch with symbolic functions. The custom symbolic function can help us to bypass some ONNX nodes which are unsupported by inference engine. diff --git a/docs/en/index.rst b/docs/en/index.rst index ca1e50d6c2..015b7163da 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -72,6 +72,7 @@ You can switch between Chinese and English documents in the lower-left corner of 06-developer-guide/support_new_backend.md 06-developer-guide/add_test_units_for_backend_ops.md 06-developer-guide/test_rewritten_models.md + 06-developer-guide/partition_model.md .. toctree:: :maxdepth: 1 diff --git a/docs/en/useful_tools.md b/docs/en/useful_tools.md index 896a892637..83b8072420 100644 --- a/docs/en/useful_tools.md +++ b/docs/en/useful_tools.md @@ -12,7 +12,7 @@ python tools/torch2onnx.py \ ${MODEL_CFG} \ ${CHECKPOINT} \ ${INPUT_IMG} \ - ${OUTPUT} \ + --work-dir ${WORK_DIR} \ --device cpu \ --log-level INFO ``` @@ -23,7 +23,7 @@ python tools/torch2onnx.py \ - `model_cfg` : The path of model config file in OpenMMLab codebase. - `checkpoint` : The path of the model checkpoint file. - `img` : The path of the image file used to convert the model. -- `output` : The path of the output ONNX model. +- `--work-dir` : Directory to save output ONNX models Default is `./work-dir`. - `--device` : The device used for conversion. If not specified, it will be set to `cpu`. - `--log-level` : To set log level which in `'CRITICAL', 'FATAL', 'ERROR', 'WARN', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'`. If not specified, it will be set to `INFO`. diff --git a/docs/zh_cn/04-developer-guide/partition_model.md b/docs/zh_cn/04-developer-guide/partition_model.md new file mode 100644 index 0000000000..70843b5e1f --- /dev/null +++ b/docs/zh_cn/04-developer-guide/partition_model.md @@ -0,0 +1,85 @@ +# How to get partitioned ONNX models + +MMDeploy 支持将PyTorch模型导出到onnx模型并进行拆分得到多个onnx模型文件,用户可以自由的对模型图节点进行标记并根据这些标记的节点定制任意的onnx模型拆分策略。在这个教程中,我们将通过具体例子来展示如何进行onnx模型拆分。在这个例子中,我们的目标是将YOLOV3模型拆分成两个部分,保留不带后处理的onnx模型,丢弃包含Anchor生成,NMS的后处理部分。 + +## 步骤 1: 添加模型标记点 + +为了进行图拆分,我们定义了`Mark`类型op,标记模型导出的边界。在实现方法上,采用`mark`装饰器对函数的输入、输出`Tensor`打标记。需要注意的是,我们的标记函数需要在某个重写函数中执行才能生效。 + +为了对YOLOV3进行拆分,首先我们需要标记模型的输入。这里为了通用性,我们标记检测器父类`BaseDetector`的`forward`方法中的`img` `Tensor`,同时为了支持其他拆分方案,也对`forward`函数的输出进行了标记,分别是`dets`, `labels`和`masks`。下面的代码是截图[mmdeploy/codebase/mmdet/models/detectors/base.py](https://github.com/open-mmlab/mmdeploy/blob/86a50e343a3a45d7bc2ba3256100accc4973e71d/mmdeploy/codebase/mmdet/models/detectors/base.py)中的一部分,可以看出我们使用`mark`装饰器标记了`__forward_impl`函数的输入输出,并在重写函数`base_detector__forward`进行了调用,从而完成了对检测器输入的标记。 + +```python +from mmdeploy.core import FUNCTION_REWRITER, mark + +@mark( + 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) +def __forward_impl(ctx, self, img, img_metas=None, **kwargs): + ... + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.base.BaseDetector.forward') +def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): + ... + # call the mark function + return __forward_impl(...) +``` + +接下来,我们只需要对`YOLOV3Head`中最后一层输出特征`Tensor`进行标记就可以将整个`YOLOV3`模型拆分成两部分。通过查看`mmdet`源码我们可以知道`YOLOV3Head`的`get_bboxes`方法中输入参数`pred_maps`就是我们想要的拆分点,因此可以在重写函数[`yolov3_head__get_bboxes`](https://github.com/open-mmlab/mmdeploy/blob/86a50e343a3a45d7bc2ba3256100accc4973e71d/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py#L14)中添加内部函数对`pred_mapes`进行标记,具体参考如下示例代码。值得注意的是,输入参数`pred_maps`是由三个`Tensor`组成的列表,所以我们在onnx模型中添加了三个`Mark`标记节点。 + +```python +from mmdeploy.core import FUNCTION_REWRITER, mark + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes') +def yolov3_head__get_bboxes(ctx, + self, + pred_maps, + img_metas, + cfg=None, + rescale=False, + with_nms=True): + # mark pred_maps + @mark('yolo_head', inputs=['pred_maps']) + def __mark_pred_maps(pred_maps): + return pred_maps + pred_maps = __mark_pred_maps(pred_maps) + ... +``` + +## 步骤 2: 添加部署配置文件 + +在完成模型中节点标记之后,我们需要创建部署配置文件,我们假设部署后端是`onnxruntime`,并模型输入是固定尺寸`608x608`,因此添加文件`configs/mmdet/detection/yolov3_partition_onnxruntime_static.py`. 我们需要在配置文件中添加基本的配置信息如`onnx_config`,如何你还不熟悉如何添加配置文件,可以参考[write_config.md](../02-how-to-run/write_config.md). + +在这个部署配置文件中, 我们需要添加一个特殊的模型分段配置字段`partition_config`. 在模型分段配置中,我们可以可以给分段策略添加一个类型名称如`yolov3_partition`,设定`apply_marks=True`。在分段方式`partition_cfg`,我们需要指定每段模型的分割起始点`start`, 终止点`end`以及保存分段onnx的文件名。需要提醒的是,各段模型起始点`start`和终止点`end`是由多个标记节点`Mark`组成,例如`'detector_forward:input'`代表`detector_forward`标记处输入所产生的标记节点。配置文件具体内容参考如下代码: + +```python +_base_ = ['./detection_onnxruntime_static.py'] + +onnx_config = dict(input_shape=[608, 608]) +partition_config = dict( + type='yolov3_partition', # the partition policy name + apply_marks=True, # should always be set to True + partition_cfg=[ + dict( + save_file='yolov3.onnx', # filename to save the partitioned onnx model + start=['detector_forward:input'], # [mark_name:input/output, ...] + end=['yolo_head:input']) # [mark_name:input/output, ...] + ]) + +``` + +## 步骤 3: 拆分onnx模型 + +添加好节点标记和部署配置文件,我们可以使用`tools/torch2onnx.py`工具导出带有`Mark`标记的完成onnx模型并根据分段策略提取分段的onnx模型文件。我们可以执行如下脚本,得到不带后处理的`YOLOV3`onnx模型文件`yolov3.onnx`,同时输出文件中也包含了添加`Mark`标记的完整模型文件`end2end.onnx`。此外,用户可以使用网页版模型可视化工具[netron](https://netron.app/)来查看和验证输出onnx模型的结构是否正确。 + +```shell +python tools/torch2onnx.py \ +configs/mmdet/detection/yolov3_partition_onnxruntime_static.py \ +../mmdetection/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py \ +https://download.openmmlab.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-608_273e_coco/yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8.pth \ +../mmdetection/demo/demo.jpg \ +--work-dir ./work-dirs/mmdet/yolov3/ort/partition +``` + +当得到分段onnx模型之后,我们可以使用mmdeploy提供的其他工具如`onnx2ncnn`, `onnx2tensorrt`来进行后续的模型部署工作。 diff --git a/docs/zh_cn/04-developer-guide/support_new_backend.md b/docs/zh_cn/04-developer-guide/support_new_backend.md index 223271ecc5..b8b1a952b0 100644 --- a/docs/zh_cn/04-developer-guide/support_new_backend.md +++ b/docs/zh_cn/04-developer-guide/support_new_backend.md @@ -1,8 +1,8 @@ -## 如何支持新的后端 +# 如何支持新的后端 MMDeploy 支持了许多后端推理引擎,但我们依然非常欢迎新后端的贡献。在本教程中,我们将介绍在 MMDeploy 中支持新后端的一般过程。 -### 必要条件 +## 必要条件 在对 MMDeploy 添加新的后端引擎之前,需要先检查所要支持的新后端是否符合一些要求: @@ -10,7 +10,7 @@ MMDeploy 支持了许多后端推理引擎,但我们依然非常欢迎新后 - 如果后端需要“.onnx”文件以外的模型文件或权重文件,则需要添加将“.onnx”文件转换为模型文件或权重文件的转换工具,该工具可以是 Python API、脚本或可执行程序。 - 强烈建议新后端可提供 Python 接口来加载后端文件和推理以进行验证。 -### 支持后端转换 +## 支持后端转换 MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx”文件,或者使用转换工具将“.onnx”转换成自己的格式。在本节中,我们将介绍支持后端转换的步骤。 @@ -155,7 +155,7 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx” 7. 为新后端引擎代码添加相关注释和单元测试:). -### 支持后端推理 +## 支持后端推理 尽管后端引擎通常用C/C++实现,但如果后端提供Python推理接口,则测试和调试非常方便。我们鼓励贡献者在MMDeploy的Python接口中支持新后端推理。在本节中,我们将介绍支持后端推理的步骤。 @@ -230,7 +230,7 @@ MMDeploy 中的后端必须支持 ONNX,因此后端能直接加载“.onnx” 5. 为新后端引擎代码添加相关注释和单元测试 :). -### 将MMDeploy作为第三方库时添加新后端 +## 将MMDeploy作为第三方库时添加新后端 前面的部分展示了如何在 MMDeploy 中添加新的后端,这需要更改其源代码。但是,如果我们将 MMDeploy 视为第三方,则上述方法不再有效。为此,添加一个新的后端需要我们预先安装另一个名为 `aenum` 的包。我们可以直接通过`pip install aenum`进行安装。 diff --git a/docs/zh_cn/04-developer-guide/support_new_model.md b/docs/zh_cn/04-developer-guide/support_new_model.md index 3ee4c84206..47ab46d4ed 100644 --- a/docs/zh_cn/04-developer-guide/support_new_model.md +++ b/docs/zh_cn/04-developer-guide/support_new_model.md @@ -1,8 +1,8 @@ -## 如何支持新的模型 +# 如何支持新的模型 我们提供了多种工具来支持模型转换 -### 函数的重写器 +## 函数的重写器 PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但与此同时 Python 的流程控制和第三方库会使得网络导出为中间语言的过程变得困难。为此我们提供了一个“MonKey path”工具将不支持的功能重写为另一个可支持中间语言导出的功能。下述是一个具体的使用例子: @@ -26,7 +26,7 @@ def repeat_static(ctx, input, *size): 可参照[这些样例代码](https://github.com/open-mmlab/mmdeploy/blob/master/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py)。 -### 模型重载器 +## 模型重载器 如果您想用另一个模块替换整个模块,我们还有另一个重载器,如下所示: @@ -61,7 +61,7 @@ class SRCNNWrapper(nn.Module): 网络中模块的所有实例都将替换为这个新类的实例。原始模块和部署配置将作为前两个参数进行传递。 -### 符号函数重写 +## 符号函数重写 PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义的。自定义符号函数可以帮助我们绕过一些推理引擎不支持的 ONNX 节点。 diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 4c2d629df3..0393d78f9c 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -40,6 +40,7 @@ 04-developer-guide/support_new_model.md 04-developer-guide/support_new_backend.md 04-developer-guide/do_regression_test.md + 04-developer-guide/partition_model.md .. toctree:: :maxdepth: 1 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py index 11acb61f61..e6125470d9 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py @@ -5,7 +5,7 @@ from mmdeploy.codebase.mmdet import (get_post_processing_params, multiclass_nms, pad_with_value_if_necessary) -from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.utils import Backend, is_dynamic_shape @@ -45,6 +45,13 @@ def yolov3_head__get_bboxes(ctx, Else: tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores """ + # mark pred_maps + @mark('yolo_head', inputs=['pred_maps']) + def __mark_pred_maps(pred_maps): + return pred_maps + + pred_maps = __mark_pred_maps(pred_maps) + is_dynamic_flag = is_dynamic_shape(ctx.cfg) num_levels = len(pred_maps) pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)] diff --git a/tools/torch2onnx.py b/tools/torch2onnx.py index e13110353b..1beffce9e9 100644 --- a/tools/torch2onnx.py +++ b/tools/torch2onnx.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse import logging +import os import os.path as osp -from mmdeploy.apis import torch2onnx -from mmdeploy.utils import get_root_logger +from mmdeploy.apis import (extract_model, get_predefined_partition_cfg, + torch2onnx) +from mmdeploy.utils import (get_ir_config, get_partition_config, + get_root_logger, load_config) def parse_args(): @@ -13,7 +16,10 @@ def parse_args(): parser.add_argument('model_cfg', help='model config path') parser.add_argument('checkpoint', help='model checkpoint path') parser.add_argument('img', help='image used to convert model model') - parser.add_argument('output', help='output onnx path') + parser.add_argument( + '--work-dir', + default='./work-dir', + help='Directory to save output files.') parser.add_argument( '--device', help='device used for conversion', default='cpu') parser.add_argument( @@ -30,29 +36,49 @@ def main(): args = parse_args() logger = get_root_logger(log_level=args.log_level) - deploy_cfg_path = args.deploy_cfg - model_cfg_path = args.model_cfg - checkpoint_path = args.checkpoint - img = args.img - output_path = args.output - work_dir, save_file = osp.split(output_path) - device = args.device - - logger.info(f'torch2onnx: \n\tmodel_cfg: {model_cfg_path} ' - f'\n\tdeploy_cfg: {deploy_cfg_path}') - try: - torch2onnx( - img, - work_dir, - save_file, - deploy_cfg=deploy_cfg_path, - model_cfg=model_cfg_path, - model_checkpoint=checkpoint_path, - device=device) - logger.info('torch2onnx success.') - except Exception as e: - logger.error(e) - logger.error('torch2onnx failed.') + logger.info(f'torch2onnx: \n\tmodel_cfg: {args.model_cfg} ' + f'\n\tdeploy_cfg: {args.deploy_cfg}') + + os.makedirs(args.work_dir, exist_ok=True) + # load deploy_cfg + deploy_cfg = load_config(args.deploy_cfg)[0] + save_file = get_ir_config(deploy_cfg)['save_file'] + + torch2onnx( + args.img, + args.work_dir, + save_file, + deploy_cfg=args.deploy_cfg, + model_cfg=args.model_cfg, + model_checkpoint=args.checkpoint, + device=args.device) + + # partition model + partition_cfgs = get_partition_config(deploy_cfg) + + if partition_cfgs is not None: + if 'partition_cfg' in partition_cfgs: + partition_cfgs = partition_cfgs.get('partition_cfg', None) + else: + assert 'type' in partition_cfgs + partition_cfgs = get_predefined_partition_cfg( + deploy_cfg, partition_cfgs['type']) + + origin_ir_file = osp.join(args.work_dir, save_file) + for partition_cfg in partition_cfgs: + save_file = partition_cfg['save_file'] + save_path = osp.join(args.work_dir, save_file) + start = partition_cfg['start'] + end = partition_cfg['end'] + dynamic_axes = partition_cfg.get('dynamic_axes', None) + + extract_model( + origin_ir_file, + start, + end, + dynamic_axes=dynamic_axes, + save_file=save_path) + logger.info(f'torch2onnx finished. Results saved to {args.work_dir}') if __name__ == '__main__':