Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RKNN support. #865

Merged
merged 16 commits into from
Sep 6, 2022
9 changes: 9 additions & 0 deletions configs/_base_/backends/rknn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None,
std_values=None,
output_tensor_type=None,
target_platform='rk3588',
optimization_level=3),
quantization_config=dict(do_quantization=False, dataset=None))
5 changes: 5 additions & 0 deletions configs/mmcls/classification_rknn_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=None)
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
codebase_config = dict(model_type='rknn')
backend_config = dict(input_size_list=[[3, 224, 224]])
17 changes: 17 additions & 0 deletions configs/mmdet/detection/detection_rknn_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[640, 640])

codebase_config = dict(model_type='rknn')

backend_config = dict(input_size_list=[[3, 640, 640]])

partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='model.onnx', # name to save the partitioned onnx model
start=['detector_forward:input'], # [mark_name:input/output, ...]
end=['yolo_head:input']) # [mark_name:input/output, ...]
])
7 changes: 7 additions & 0 deletions configs/mmseg/segmentation_rknn_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[512, 512])

codebase_config = dict(model_type='rknn')

backend_config = dict(input_size_list=[[3, 512, 512]])
146 changes: 73 additions & 73 deletions docs/en/03-benchmark/supported_models.md

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions docs/en/05-supported-backends/rknn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# RKNN support

This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`.

## Installation

It is recommended to create a virtual environment for the project.

1. get RKNN-Toolkit2 through:

```
git clone https://github.com/rockchip-linux/rknn-toolkit2
```

2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc).

3. reinstall MMDeploy from source following the [instructions](../01-how-to-build/build_from_source.md). Note that there are conflicts between the pip dependencies of MMDeploy and RKNN. Here is the suggested packages versions for python 3.6:

```
protobuf==3.19.4
onnx==1.8.0
onnxruntime==1.8.0
torch==1.8.0
torchvision==0.9.0
```

To work with models from [MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs/get_started.md), you may need to install it additionally.

## Usage

Example:

```bash
python tools/deploy.py \
lvhan028 marked this conversation as resolved.
Show resolved Hide resolved
configs/mmdet/detection/detection_rknn_static.py \
/mmdetection_dir/mmdetection/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py \
/tmp/snapshots/yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8.pth \
tests/data/tiger.jpeg \
--work-dir ../deploy_result \
Copy link
Collaborator

@lvhan028 lvhan028 Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--dump-info is not supported.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we did not support SDK stuff here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--dump-info is the precondition. Without it, how can SDK do testing?
It's not SDK's responsibility to get the deployment meta info because it has no idea how model converter does the postprocessing.

--device cpu
```

## Deployment config

With the deployment config, you can modify the `backend_config` for your preference. A example `backend_config` of mmclassification is shown as below:
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved

```python
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None,
std_values=None,
output_tensor_type=None,
target_platform='rk3588',
optimization_level=3),
quantization_config=dict(do_quantization=False, dataset=None),
input_size_list=[[3, 224, 224]])

```

The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`.

## Troubleshooting

- Quantization fails.

Empirically, RKNN require the inputs not normalized if `do_quantization` is set to `False`. Please modify the settings of `Normalize` in the `model_cfg` from

```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
```

to

```python
img_norm_cfg = dict(
mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
```

Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`.
144 changes: 73 additions & 71 deletions docs/zh_cn/03-benchmark/supported_models.md

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions mmdeploy/apis/rknn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.rknn import is_available

__all__ = ['is_available']

if is_available():
from mmdeploy.backend.rknn.onnx2rknn import onnx2rknn as _onnx2rknn
from ..core import PIPELINE_MANAGER
onnx2rknn = PIPELINE_MANAGER.register_pipeline()(_onnx2rknn)

__all__ += ['onnx2rknn']
31 changes: 31 additions & 0 deletions mmdeploy/backend/rknn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import re
import subprocess


def is_available():
"""Check whether rknn is installed.

Returns:
bool: True if rknn package is installed.
"""
return importlib.util.find_spec('rknn') is not None


def device_available():
"""Check whether device available.

Returns:
bool: True if the device is available.
"""
ret = subprocess.check_output('adb devices', shell=True)
match = re.search(r'\\n\w+\\tdevice', str(ret))
return match is not None


__all__ = []

if is_available():
from .wrapper import RKNNWrapper
__all__ += ['RKNNWrapper']
82 changes: 82 additions & 0 deletions mmdeploy/backend/rknn/onnx2rknn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmcv
from rknn.api import RKNN

from mmdeploy.utils import (get_common_config, get_onnx_config,
get_partition_config, get_quantization_config,
get_root_logger, load_config)
from mmdeploy.utils.config_utils import get_backend_config


def onnx2rknn(onnx_model: str,
output_path: str,
deploy_cfg: Union[str, mmcv.Config],
dataset_file: Optional[str] = None,
**kwargs):
"""Convert ONNX to RKNN.

RKNN-Toolkit2 is a software development kit for users to perform model
conversion, inference and performance evaluation on PC and Rockchip
NPU platforms.

Args:
onnx_model (str): Input onnx model.
output_path (str): File path to save RKNN model.
device (str): A string specifying device, defaults to 'cuda:0'.
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
input_shapes (Sequence[Sequence[int]] | None): Shapes for PPLNN
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
optimization, default to None.

Examples:
>>> from mmdeploy.apis.rknn import from_onnx
>>>
>>> from_onnx(onnx_model = 'example.onnx',
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
output_file_prefix = 'example')
"""
logger = get_root_logger()
# load deploy_cfg if necessary
deploy_cfg = load_config(deploy_cfg)[0]

common_params = get_common_config(deploy_cfg)
# common_params.update(dict(mean_values=[0, 0, 0], std_values=[1, 1, 1]))
AllentDan marked this conversation as resolved.
Show resolved Hide resolved
onnx_params = get_onnx_config(deploy_cfg)
quantization_cfg = get_quantization_config(deploy_cfg)

input_names = onnx_params.get('input_names', None)
output_names = onnx_params.get('output_names', None)
input_size_list = get_backend_config(deploy_cfg).get(
'input_size_list', None)
# update output_names for partition models
if get_partition_config(deploy_cfg) is not None:
import onnx
_onnx_model = onnx.load(onnx_model)
output_names = [node.name for node in _onnx_model.graph.output]

rknn = RKNN(verbose=True)
rknn.config(**common_params)
ret = rknn.load_onnx(
model=onnx_model,
inputs=input_names,
input_size_list=input_size_list,
outputs=output_names)
if ret != 0:
logger.error('Load model failed!')
exit(ret)

dataset_cfg = quantization_cfg.get('dataset', None)
do_quantization = quantization_cfg.get('do_quantization', False)
if dataset_cfg is None and dataset_file is None:
do_quantization = False
logger.warning('no dataset passed in, quantization is skipped')
if dataset_file is None:
dataset_file = dataset_cfg
ret = rknn.build(do_quantization=do_quantization, dataset=dataset_file)
if ret != 0:
logger.error('Build model failed!')
exit(ret)

ret = rknn.export_rknn(output_path)
if ret != 0:
logger.error('Export rknn model failed!')
exit(ret)
66 changes: 66 additions & 0 deletions mmdeploy/backend/rknn/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from rknn.api import RKNN

from mmdeploy.utils import Backend, get_root_logger
from mmdeploy.utils.timer import TimeCounter
from ..base import BACKEND_WRAPPER, BaseWrapper


@BACKEND_WRAPPER.register_module(Backend.RKNN.value)
class RKNNWrapper(BaseWrapper):
"""PPLNN wrapper for inference.

Args:
model (str): Path of input RKNN model file.
target_platform (str): Device to put model.
AllentDan marked this conversation as resolved.
Show resolved Hide resolved

Examples:
>>> from mmdeploy.backend.pplnn import PPLNNWrapper
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
>>> import torch
>>>
>>> model = 'model.rknn'
>>> model = RKNNWrapper(model)
AllentDan marked this conversation as resolved.
Show resolved Hide resolved
>>> inputs = dict(input=torch.randn(1, 3, 224, 224))
>>> outputs = model(inputs)
>>> print(outputs)
"""

def __init__(self,
model: str,
common_config: Dict,
output_names: Optional[Sequence[str]] = None,
**kwargs):
logger = get_root_logger()
# Create RKNN object
self.rknn = RKNN(verbose=True)
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
self.rknn.load_rknn(model)
ret = self.rknn.init_runtime(target=common_config['target_platform'])
if ret != 0:
logger.error('Init runtime environment failed!')
exit(ret)
super().__init__(output_names)

def forward(self, inputs: Dict[str,
torch.Tensor]) -> Sequence[torch.Tensor]:
"""Run forward inference. Note that the shape of the input tensor is
NxCxHxW while RKNN only accepts the numpy inputs of NxHxWxC. There is a
permute operation outside RKNN inference.

Args:
inputs (Dict[str, torch.Tensor]): Input name and tensor pairs.

Return:
Sequence[torch.Tensor]: The output tensors.
"""
rknn_out = self.__rknnnn_execute(
[i.permute(0, 2, 3, 1).cpu().numpy() for i in inputs.values()])
return [torch.from_numpy(out) for out in rknn_out]

@TimeCounter.count_time(Backend.RKNN.value)
def __rknnnn_execute(self, inputs: Sequence[np.array]):
"""Run inference with RKNN."""
return self.rknn.inference(inputs)
9 changes: 8 additions & 1 deletion mmdeploy/codebase/base/backend_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from mmdeploy.utils import (SDK_TASK_MAP, Backend, get_backend_config,
get_ir_config, get_task_type)
get_common_config, get_ir_config, get_task_type)


class BaseBackendModel(torch.nn.Module, metaclass=ABCMeta):
Expand Down Expand Up @@ -106,6 +106,13 @@ def _build_wrapper(backend: Backend,
model=backend_files[0],
input_names=input_names,
output_names=output_names)
elif backend == Backend.RKNN:
from mmdeploy.backend.rknn import RKNNWrapper
common_config = get_common_config(deploy_cfg)
return RKNNWrapper(
model=backend_files[0],
common_config=common_config,
output_names=output_names)
elif backend == Backend.SNPE:
from mmdeploy.backend.snpe import SNPEWrapper
uri = None
Expand Down
19 changes: 19 additions & 0 deletions mmdeploy/codebase/mmcls/deploy/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list:
return pred[np.argsort(pred[:, 0])][np.newaxis, :, 1]


@__BACKEND_MODEL.register_module('rknn')
class RKNNEnd2EndModel(End2EndModel):
"""RKNN inference class, converts RKNN output to mmcls format."""

def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
List[np.ndarray]:
"""The interface for forward test.

Args:
imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.

Returns:
List[np.ndarray]: A list of classification prediction.
"""
outputs = self.wrapper({self.input_name: imgs})
outputs = [out.numpy() for out in outputs]
return outputs


def get_classes_from_config(model_cfg: Union[str, mmcv.Config]):
"""Get class name from config.

Expand Down
3 changes: 3 additions & 0 deletions mmdeploy/codebase/mmdet/core/point_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.core.anchor.MlvlPointGenerator.single_level_grid_priors',
backend=Backend.TENSORRT.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.core.anchor.MlvlPointGenerator.single_level_grid_priors',
backend=Backend.RKNN.value)
def mlvl_point_generator__single_level_grid_priors__tensorrt(
ctx,
self,
Expand Down
Loading