diff --git a/.github/workflows/backend-ncnn.yml b/.github/workflows/backend-ncnn.yml index f5508621e0..1a387e3fe1 100644 --- a/.github/workflows/backend-ncnn.yml +++ b/.github/workflows/backend-ncnn.yml @@ -1,4 +1,4 @@ -name: backend +name: backend-ncnn on: push: @@ -23,7 +23,6 @@ jobs: matrix: python-version: [3.7] torch: [1.9.0] - mmcv: [1.4.2] include: - torch: 1.9.0 torch_version: torch1.9 diff --git a/.github/workflows/backend-snpe.yml b/.github/workflows/backend-snpe.yml new file mode 100644 index 0000000000..2bd5b5be21 --- /dev/null +++ b/.github/workflows/backend-snpe.yml @@ -0,0 +1,60 @@ +name: backend-snpe + +on: + push: + paths-ignore: + - "demo/**" + - "tools/**" + + pull_request: + paths-ignore: + - "demo/**" + - "tools/**" + - "docs/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build_sdk_demo: + runs-on: ubuntu-18.04 + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + submodules: 'recursive' + - name: update + run: sudo apt update + - name: Install dependencies + run: | + sudo apt install wget libprotobuf-dev protobuf-compiler + sudo apt update + sudo apt install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev libc++1-9 libc++abi1-9 + sudo add-apt-repository ppa:ignaciovizzo/opencv3-nonfree + sudo apt install libopencv-dev + pkg-config --libs opencv + - name: Install snpe + run: | + wget https://media.githubusercontent.com/media/tpoisonooo/mmdeploy_snpe_testdata/main/snpe-1.59.tar.gz + tar xf snpe-1.59.tar.gz + pushd snpe-1.59.0.3230 + pwd + popd + - name: Build SDK Demo with SNPE backend + run: | + mkdir -p build && pushd build + export SNPE_ROOT=/home/runner/work/mmdeploy/mmdeploy/snpe-1.59.0.3230 + export LD_LIBRARY_PATH=${SNPE_ROOT}/lib/x86_64-linux-clang:${LD_LIBRARY_PATH} + export MMDEPLOY_SNPE_X86_CI=1 + cmake .. -DCMAKE_CXX_COMPILER=g++-7 -DMMDEPLOY_SHARED_LIBS=ON -DMMDEPLOY_BUILD_SDK=ON -DMMDEPLOY_BUILD_SDK_PYTHON_API=OFF -DMMDEPLOY_TARGET_DEVICES=cpu -DMMDEPLOY_TARGET_BACKENDS=snpe -DMMDEPLOY_CODEBASES=all + make -j2 + make install + pushd install/example + mkdir build && pushd build + cmake .. 
-DMMDeploy_DIR=${PWD}/../../lib/cmake/MMDeploy + make -j2 + ls ./* + popd + popd + popd diff --git a/.gitignore b/.gitignore index 0c5d08c9ee..7dea9d9ea1 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,7 @@ bin/ mmdeploy/backend/ncnn/onnx2ncnn /mmdeploy-* + +# snpe +grpc-cpp-plugin +service/snpe/grpc_cpp_plugin diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0bb388f282..236e01a0cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,6 +3,7 @@ repos: rev: 4.0.1 hooks: - id: flake8 + args: ["--exclude=*/client/inference_pb2.py,*/client/inference_pb2_grpc.py"] - repo: https://github.com/PyCQA/isort rev: 5.10.1 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index d2a8f13d61..db0be758e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,6 +113,7 @@ if (MMDEPLOY_BUILD_SDK) if (NOT MMDEPLOY_SHARED_LIBS) mmdeploy_add_deps(pplnn BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS pplnn) endif () + mmdeploy_add_deps(snpe BACKENDS ${MMDEPLOY_TARGET_BACKENDS} DEPS snpe) include(CMakePackageConfigHelpers) # generate the config file that is includes the exports diff --git a/README.md b/README.md index ee24307bb4..ab966e8bed 100644 --- a/README.md +++ b/README.md @@ -55,9 +55,9 @@ The currently supported codebases and models are as follows, and more will be in Models can be exported and run in the following backends, and more will be compatible -| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | more | -| ------------ | -------- | ------ | ---- | -------- | -------- | ---------------------------------------------- | -| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/en/03-benchmark/benchmark.md) | +| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | snpe | more | +| ------------ | -------- | ------ | ---- | -------- | -------- | ---- | ---------------------------------------------- | +| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/en/03-benchmark/benchmark.md) | ### Efficient and scalable C/C++ SDK Framework @@ -73,6 +73,7 @@ Please read [getting_started.md](docs/en/get_started.md) for the basic usage of - [Build for Win10](docs/en/01-how-to-build/windows.md) - [Build for Android](docs/en/01-how-to-build/android.md) - [Build for Jetson](docs/en/01-how-to-build/jetsons.md) + - [Build for SNPE](docs/en/01-how-to-build/snpe.md) - User Guide - [How to convert model](docs/en/02-how-to-run/convert_model.md) - [How to write config](docs/en/02-how-to-run/write_config.md) diff --git a/README_zh-CN.md b/README_zh-CN.md index 5718b47639..e77b4da141 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -53,9 +53,9 @@ MMDeploy 是 [OpenMMLab](https://openmmlab.com/) 模型部署工具箱,**为 ### 支持多种推理后端 -| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | more | -| ------------ | -------- | ------ | ---- | -------- | ------------------------------------------------- | -| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/zh_cn/03-benchmark/benchmark.md) | +| ONNX Runtime | TensorRT | ppl.nn | ncnn | OpenVINO | LibTorch | snpe | more | +| ------------ | -------- | ------ | ---- | -------- | -------- | ---- | ------------------------------------------------- | +| ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | [benchmark](docs/zh_cn/03-benchmark/benchmark.md) | ### SDK 可高度定制化 @@ -71,6 +71,7 @@ MMDeploy 是 [OpenMMLab](https://openmmlab.com/) 模型部署工具箱,**为 - [Build for Win10](docs/zh_cn/01-how-to-build/windows.md) - [Build for Android](docs/zh_cn/01-how-to-build/android.md) - [Build for Jetson](docs/en/01-how-to-build/jetsons.md) + - [Build for 
SNPE](docs/zh_cn/01-how-to-build/snpe.md) - 使用 - [把模型转换到推理 Backend](docs/zh_cn/02-how-to-run/convert_model.md) - [配置转换参数](docs/zh_cn/02-how-to-run/write_config.md) diff --git a/configs/_base_/backends/snpe.py b/configs/_base_/backends/snpe.py new file mode 100644 index 0000000000..a96bee9939 --- /dev/null +++ b/configs/_base_/backends/snpe.py @@ -0,0 +1 @@ +backend_config = dict(type='snpe') diff --git a/configs/mmcls/classification_snpe_static.py b/configs/mmcls/classification_snpe_static.py new file mode 100644 index 0000000000..f80140a3ac --- /dev/null +++ b/configs/mmcls/classification_snpe_static.py @@ -0,0 +1,3 @@ +_base_ = ['./classification_static.py', '../_base_/backends/snpe.py'] + +onnx_config = dict(input_shape=None) diff --git a/configs/mmedit/super-resolution/super-resolution_snpe_static-256x256.py b/configs/mmedit/super-resolution/super-resolution_snpe_static-256x256.py new file mode 100644 index 0000000000..2d1291646f --- /dev/null +++ b/configs/mmedit/super-resolution/super-resolution_snpe_static-256x256.py @@ -0,0 +1,2 @@ +_base_ = ['./super-resolution_static.py', '../../_base_/backends/snpe.py'] +onnx_config = dict(input_shape=[256, 256]) diff --git a/configs/mmocr/text-detection/text-detection_snpe_static.py b/configs/mmocr/text-detection/text-detection_snpe_static.py new file mode 100644 index 0000000000..a47ef9464d --- /dev/null +++ b/configs/mmocr/text-detection/text-detection_snpe_static.py @@ -0,0 +1,3 @@ +_base_ = ['./text-detection_static.py', '../../_base_/backends/snpe.py'] + +onnx_config = dict(input_shape=None) diff --git a/configs/mmpose/pose-detection_snpe_static-256x256.py b/configs/mmpose/pose-detection_snpe_static-256x256.py new file mode 100644 index 0000000000..4b2e6791d0 --- /dev/null +++ b/configs/mmpose/pose-detection_snpe_static-256x256.py @@ -0,0 +1,3 @@ +_base_ = ['./pose-detection_static.py', '../_base_/backends/snpe.py'] + +onnx_config = dict(input_shape=[256, 256]) diff --git a/configs/mmseg/segmentation_snpe_static-512x1024.py b/configs/mmseg/segmentation_snpe_static-512x1024.py new file mode 100644 index 0000000000..7def73ce76 --- /dev/null +++ b/configs/mmseg/segmentation_snpe_static-512x1024.py @@ -0,0 +1,3 @@ +_base_ = ['./segmentation_static.py', '../_base_/backends/snpe.py'] + +onnx_config = dict(input_shape=[1024, 512]) diff --git a/csrc/mmdeploy/net/CMakeLists.txt b/csrc/mmdeploy/net/CMakeLists.txt index a7cd00d3de..3b42740c27 100644 --- a/csrc/mmdeploy/net/CMakeLists.txt +++ b/csrc/mmdeploy/net/CMakeLists.txt @@ -22,5 +22,9 @@ if ("openvino" IN_LIST MMDEPLOY_TARGET_BACKENDS) add_subdirectory(openvino) endif () +if ("snpe" IN_LIST MMDEPLOY_TARGET_BACKENDS) + add_subdirectory(snpe) +endif () + mmdeploy_add_module(${PROJECT_NAME} net_module.cpp) add_library(mmdeploy::net_module ALIAS ${PROJECT_NAME}) diff --git a/csrc/mmdeploy/net/snpe/CMakeLists.txt b/csrc/mmdeploy/net/snpe/CMakeLists.txt new file mode 100644 index 0000000000..2f8af24dc3 --- /dev/null +++ b/csrc/mmdeploy/net/snpe/CMakeLists.txt @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +project(mmdeploy_snpe_net) + +add_library(snpe SHARED IMPORTED) + +if(NOT DEFINED ENV{MMDEPLOY_SNPE_X86_CI}) + set(sub_dir "aarch64-android-clang6.0") +else() + set(sub_dir "x86_64-linux-clang") +endif() + +if (NOT EXISTS $ENV{SNPE_ROOT}/lib/${sub_dir}/) + message(ERROR "SNPE_ROOT directory not exist: $ENV{SNPE_ROOT}/lib/${sub_dir}/") +endif() +message(STATUS "SNPE lib directory $ENV{SNPE_ROOT}/lib/${sub_dir}/") + +set_target_properties(snpe PROPERTIES + IMPORTED_LOCATION "$ENV{SNPE_ROOT}/lib/${sub_dir}/libSNPE.so" + INTERFACE_INCLUDE_DIRECTORIES "$ENV{SNPE_ROOT}/include/zdl" +) + +mmdeploy_add_module(${PROJECT_NAME} snpe_net.cpp) +target_link_libraries(${PROJECT_NAME} PRIVATE snpe) +add_library(mmdeploy::snpe_net ALIAS ${PROJECT_NAME}) diff --git a/csrc/mmdeploy/net/snpe/snpe_net.cpp b/csrc/mmdeploy/net/snpe/snpe_net.cpp new file mode 100644 index 0000000000..5fb05b87f3 --- /dev/null +++ b/csrc/mmdeploy/net/snpe/snpe_net.cpp @@ -0,0 +1,262 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "snpe_net.h" + +#include "mmdeploy/core/logger.h" +#include "mmdeploy/core/model.h" +#include "mmdeploy/core/utils/formatter.h" + +namespace mmdeploy { + +SNPENet::~SNPENet() {} + +std::string SNPENet::ShapeStr(zdl::DlSystem::ITensor* pTensor) { + std::string str; + + str += "["; + auto shape = pTensor->getShape(); + for (int i = 0; i < shape.rank(); ++i) { + str += std::to_string(shape[i]); + str += ","; + } + str += ']'; + return str; +} + +void SNPENet::Build(std::unique_ptr& container, + zdl::DlSystem::Runtime_t runtime, zdl::DlSystem::RuntimeList runtimeList, + bool useUserSuppliedBuffers, zdl::DlSystem::PlatformConfig platformConfig) { + zdl::SNPE::SNPEBuilder snpeBuilder(container.get()); + + if (runtimeList.empty()) { + runtimeList.add(runtime); + } + + snpe_ = + snpeBuilder.setOutputLayers({}) + .setRuntimeProcessorOrder(runtimeList) + .setUseUserSuppliedBuffers(useUserSuppliedBuffers) + .setPlatformConfig(platformConfig) + .setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::SUSTAINED_HIGH_PERFORMANCE) + .build(); + return; +} + +void SNPENet::copy_output(const zdl::DlSystem::ITensor* from, Tensor& to) { + auto hwc_to_chw = [](const zdl::DlSystem::TensorShape& shape) -> bool { + if (shape.rank() != 4 || (shape[1] == 1 && shape[2] > 1 && shape[3] > 1)) { + return false; + } + return true; + }; + + auto output_shape = from->getShape(); + + if (to.size() != from->getSize()) { + TensorShape tensor_shape; + for (int j = 0; j < output_shape.rank(); ++j) { + tensor_shape.push_back(output_shape[j]); + } + + if (hwc_to_chw(output_shape)) { + auto tmp = output_shape[3]; + output_shape[3] = output_shape[1]; + output_shape[1] = tmp; + } + to.Reshape(tensor_shape); + } + + float* pto = to.data(); + + if (output_shape.rank() != 4 || + (output_shape[1] == 1 && output_shape[2] > 1 && output_shape[3] > 1)) { + // skip [1,1,w>1,h>1] for segmentation task + for (auto it = from->cbegin(); it != from->cend(); ++it, ++pto) { + *pto = *it; + } + } else { + const int channel = output_shape[1]; + const int panel = output_shape[2] * output_shape[3]; + + int i = 0; + // HWC to CHW + for (auto it = from->cbegin(); it != from->cend(); ++it, ++i) { + int channel_idx = i % channel; + int panel_idx = i / channel; + pto[channel_idx * panel + panel_idx] = *it; + } + } + return; +} + +void SNPENet::copy_input(const Tensor& from, zdl::DlSystem::ITensor* to) { + if (from.size() != to->getSize()) { + MMDEPLOY_ERROR("input tensor size not match"); + return; + } + + const float* pfrom = from.data(); + + auto 
input_shape = to->getShape(); + if (input_shape.rank() == 4) { + const int channel = input_shape[3]; + const int panel = input_shape[1] * input_shape[2]; + + int i = 0; + // CHW to HWC + for (auto it = to->begin(); it != to->end(); ++it, ++i) { + int channel_index = i % channel; + int panel_index = (i / channel) % panel; + + *it = pfrom[channel_index * panel + panel_index]; + } + + } else { + for (auto it = to->begin(); it != to->end(); ++it, ++pfrom) { + *it = *pfrom; + } + } +} + +Result SNPENet::Init(const Value& args) { + auto& context = args["context"]; + device_ = context["device"].get(); + stream_ = context["stream"].get(); + if (!device_.is_host()) { + return Status(eNotSupported); + } + + auto name = args["name"].get(); + auto model = context["model"].get(); + OUTCOME_TRY(auto config, model.GetModelConfig(name)); + + std::string content; + OUTCOME_TRY(content, model.ReadFile(config.net)); + char* model_ptr = const_cast(content.data()); + container_ = + zdl::DlContainer::IDlContainer::open(reinterpret_cast(model_ptr), content.size()); + if (container_ == nullptr) { + MMDEPLOY_ERROR("Load .dlc failed: {}", config.net); + return Status(eInvalidArgument); + } + + zdl::DlSystem::Runtime_t runtime = zdl::DlSystem::Runtime_t::GPU; + if (!zdl::SNPE::SNPEFactory::isRuntimeAvailable(runtime)) { + MMDEPLOY_WARN("Selected runtime not present. Falling back to CPU.\n"); + runtime = zdl::DlSystem::Runtime_t::CPU; + } + + zdl::DlSystem::RuntimeList runtimeList; + // Add CPU backend to support fallback + runtimeList.add(zdl::DlSystem::Runtime_t::CPU); + runtimeList.add(runtime); + zdl::DlSystem::PlatformConfig platformConfig; + Build(container_, runtime, runtimeList, false, platformConfig); + + // init internal input tensor list + const auto& inputTensorNamesRef = snpe_->getInputTensorNames(); + const auto& inputTensorNames = *inputTensorNamesRef; + inputs_internal_.resize(inputTensorNames.size()); + + for (int i = 0; i < inputTensorNames.size(); ++i) { + const auto& inputShape_opt = snpe_->getInputDimensions(inputTensorNames.at(i)); + const auto& inputShape = *inputShape_opt; + + inputs_internal_[i] = zdl::SNPE::SNPEFactory::getTensorFactory().createTensor(inputShape); + + std::string info = + std::string(inputTensorNames.at(i)) + " shape: " + ShapeStr(inputs_internal_[i].get()); + MMDEPLOY_INFO(info); + + input_tensor_map_.add(inputTensorNames.at(i), inputs_internal_[i].get()); + + input_tensors_.emplace_back(TensorDesc{ + Device("cpu"), + DataType::kFLOAT, + {}, + std::string(inputTensorNames.at(i)), + }); + } + + const auto& outputTensorNamesRef = snpe_->getOutputTensorNames(); + const auto& outputTensorNames = *outputTensorNamesRef; + for (int i = 0; i < outputTensorNames.size(); ++i) { + output_tensors_.emplace_back(TensorDesc{ + Device("cpu"), + DataType::kFLOAT, + {}, + std::string(outputTensorNames.at(i)), + }); + } + + return success(); +} + +Result SNPENet::Deinit() { return success(); } + +Result SNPENet::Reshape(Span input_shapes) { + for (size_t i = 0; i < input_shapes.size(); ++i) { + input_tensors_[i].Reshape(input_shapes[i]); + } + return success(); +} + +Result> SNPENet::GetInputTensors() { return input_tensors_; } + +Result> SNPENet::GetOutputTensors() { return output_tensors_; } + +Result SNPENet::Forward() { + OUTCOME_TRY(stream_.Wait()); + + { + // copy input to itensor buffer + for (auto& tensor : input_tensors_) { + const auto& name = tensor.desc().name; + auto pbuffer = input_tensor_map_.getTensor(name.c_str()); + + copy_input(tensor, pbuffer); + } + } + + // A tensor map 
for SNPE execution outputs + zdl::DlSystem::TensorMap output_map; + { + // real inference + bool success = snpe_->execute(input_tensor_map_, output_map); + if (!success) { + MMDEPLOY_ERROR("snpe Inference error: {}", std::string(zdl::DlSystem::getLastErrorString())); + return Status(eFail); + } + } + + { + // extract output buffer to tensor + auto names = output_map.getTensorNames(); + for (size_t i = 0; i < names.size(); ++i) { + const zdl::DlSystem::ITensor* pbuffer = output_map.getTensor(names.at(i)); + + auto& tensor = output_tensors_[i]; + copy_output(pbuffer, tensor); + } + } + return success(); +} + +class SNPENetCreator : public Creator { + public: + const char* GetName() const override { return "snpe"; } + int GetVersion() const override { return 0; } + std::unique_ptr Create(const Value& args) override { + auto p = std::make_unique(); + if (auto r = p->Init(args)) { + return p; + } else { + MMDEPLOY_ERROR("error creating SNPENet: {}", r.error().message().c_str()); + return nullptr; + } + } +}; + +REGISTER_MODULE(Net, SNPENetCreator); + +} // namespace mmdeploy diff --git a/csrc/mmdeploy/net/snpe/snpe_net.h b/csrc/mmdeploy/net/snpe/snpe_net.h new file mode 100644 index 0000000000..4058558613 --- /dev/null +++ b/csrc/mmdeploy/net/snpe/snpe_net.h @@ -0,0 +1,61 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_SRC_NET_SNPE_SNPE_NET_H_ +#define MMDEPLOY_SRC_NET_SNPE_SNPE_NET_H_ + +#include +#include +#include + +#include "DiagLog/IDiagLog.hpp" +#include "DlContainer/IDlContainer.hpp" +#include "DlSystem/DlEnums.hpp" +#include "DlSystem/DlError.hpp" +#include "DlSystem/ITensorFactory.hpp" +#include "DlSystem/IUserBuffer.hpp" +#include "DlSystem/PlatformConfig.hpp" +#include "DlSystem/RuntimeList.hpp" +#include "DlSystem/UserBufferMap.hpp" +#include "SNPE/SNPE.hpp" +#include "SNPE/SNPEBuilder.hpp" +#include "SNPE/SNPEFactory.hpp" +#include "mmdeploy/core/net.h" + +namespace mmdeploy { + +class SNPENet : public Net { + public: + ~SNPENet() override; + Result Init(const Value& args) override; + Result Deinit() override; + Result> GetInputTensors() override; + Result> GetOutputTensors() override; + Result Reshape(Span input_shapes) override; + Result Forward() override; + Result ForwardAsync(Event* event) override { return Status(eNotSupported); }; + + private: + void Build(std::unique_ptr& container, + zdl::DlSystem::Runtime_t runtime, zdl::DlSystem::RuntimeList runtimeList, + bool useUserSuppliedBuffers, zdl::DlSystem::PlatformConfig platformConfig); + + std::string ShapeStr(zdl::DlSystem::ITensor* pTensor); + + void copy_output(const zdl::DlSystem::ITensor* from, Tensor& to); + void copy_input(const Tensor& from, zdl::DlSystem::ITensor* to); + + Device device_; + Stream stream_; + std::vector input_tensors_; + std::vector output_tensors_; + + std::unique_ptr snpe_; + std::unique_ptr container_; + + std::vector> inputs_internal_; + zdl::DlSystem::TensorMap input_tensor_map_; +}; + +} // namespace mmdeploy + +#endif // MMDEPLOY_SRC_NET_SNPE_SNPE_NET_H_ diff --git a/demo/csrc/image_classification.cpp b/demo/csrc/image_classification.cpp index d600cb07a1..5e64581b9f 100644 --- a/demo/csrc/image_classification.cpp +++ b/demo/csrc/image_classification.cpp @@ -6,7 +6,7 @@ int main(int argc, char* argv[]) { if (argc != 4) { - fprintf(stderr, "usage:\n image_classification device_name model_path image_path\n"); + fprintf(stderr, "usage:\n image_classification device_name dump_model_directory image_path\n"); return 1; } auto device_name = argv[1]; diff --git 
a/docs/en/01-how-to-build/build_from_source.md b/docs/en/01-how-to-build/build_from_source.md index d20b8a4563..2aa6ecc851 100644 --- a/docs/en/01-how-to-build/build_from_source.md +++ b/docs/en/01-how-to-build/build_from_source.md @@ -37,3 +37,4 @@ Please visit the following links to find out how to build MMDeploy according to - [Windows](windows.md) - [Android-aarch64](android.md) - [NVIDIA Jetson](jetsons.md) +- [SNPE](snpe.md) diff --git a/docs/en/01-how-to-build/snpe.md b/docs/en/01-how-to-build/snpe.md new file mode 100644 index 0000000000..81aa5ca072 --- /dev/null +++ b/docs/en/01-how-to-build/snpe.md @@ -0,0 +1,194 @@ +# Build for SNPE + +It is quite simple to support snpe backend: Client/Server mode. + +this mode + +1. Can split `model convert` and `inference` environments; + +- Inference irrelevant matters are done on host +- We can get the real running results of gpu/npu instead of cpu simulation values + +2. Can cover cost-sensitive device, armv7/risc-v/mips chips meet product requirements, but often have limited support for Python; + +3. Can simplify mmdeploy installation steps. If you only want to convert snpe model and test, you don't need to compile the .whl package. + +## 1. Run inference server + +Download the prebuilt snpe inference server package, `adb push` it to the phone and execute. +Note that **the phone must have a qcom chip**. + +```bash +$ wget https://media.githubusercontent.com/media/tpoisonooo/mmdeploy_snpe_testdata/main/snpe-inference-server-1.59.tar.gz +... +$ sudo apt install adb +$ adb push snpe-inference-server-1.59.tar.gz /data/local/tmp/ + +# decompress and execute +$ adb shell +venus:/ $ cd /data/local/tmp +130|venus:/data/local/tmp $ tar xvf snpe-inference-server-1.59.tar.gz +... +130|venus:/data/local/tmp $ source export1.59.sh +130|venus:/data/local/tmp $ ./inference_server +... + Server listening on [::]:60000 +``` + +At this point the inference service should print all the ipv6 and ipv4 addresses of the device and listen on the port. + +tips: + +- If `adb devices` cannot find the device, may be: + - Some cheap cables can only charge and cannot transmit data + - or the "developer mode" of the phone is not turned on +- If you need to compile the binary by self, please refer to [NDK Cross Compiling snpe Inference Service](../appendix/cross_build_snpe_service.md) +- If a `segmentation fault` occurs when listening on a port, it may be because: + - The port number is already occupied, use another port + +## 2. Build mmdeploy + +### 1) Environment + +| Matters | Version | Remarks | +| ------- | ------------------ | ---------------------- | +| host OS | ubuntu18.04 x86_64 | snpe specified version | +| Python | **3.6.0** | snpe specified version | + +### 2) Installation + +Download [snpe-1.59 from the official website](https://developer.qualcomm.com/qfile/69652/snpe-1.59.0.zip) + +```bash +$ unzip snpe-1.59.0.zip +$ export SNPE_ROOT=${PWD}/snpe-1.59.0.3230 +$ cd /path/to/mmdeploy +$ export PYTHONPATH=${PWD}/service/snpe/client:${SNPE_ROOT}/lib/python:${PYTHONPATH} +$ export LD_LIBRARY_PATH=${SNPE_ROOT}/lib/x86_64-linux-clang:${LD_LIBRARY_PATH} +$ export PATH=${SNPE_ROOT}/bin/x86_64-linux-clang:${PATH} +$ python3 -m pip install -e . +``` + +## 3. Test the model + +Take Resnet-18 as an example. First refer to [documentation to install mmcls](https://github.com/open-mmlab/mmclassification) and use `tools/deploy.py` to convert the model. 
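The full `tools/deploy.py` command line is shown in the next block. If you already have an exported `end2end.onnx` and only need the ONNX-to-DLC step, the conversion can also be driven from Python through the backend API added in this patch; the sketch below assumes `snpe-onnx-to-dlc` is reachable on `PATH` (set up in section 2) and uses placeholder paths.

```python
# Minimal sketch: convert an existing ONNX file to a SNPE .dlc from Python.
# Requires snpe-onnx-to-dlc on PATH; 'resnet18/end2end.onnx' is a placeholder.
from mmdeploy.apis.snpe import from_onnx, is_available

assert is_available(), 'snpe-onnx-to-dlc not found; check SNPE_ROOT and PATH'

# Writes resnet18/end2end.dlc, matching the file that tools/test.py consumes.
from_onnx('resnet18/end2end.onnx', 'resnet18/end2end')
```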
+ +```bash +$ export MODEL_CONFIG=/path/to/mmclassification/configs/resnet/resnet18_8xb16_cifar10.py +$ export MODEL_PATH=https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth + +# Convert the model +$ cd /path/to/mmdeploy +$ python3 tools/deploy.py configs/mmcls/classification_snpe_static.py $MODEL_CONFIG $MODEL_PATH /path/to/test.png --work-dir resnet18 --device cpu --uri 10.0.0.1\:60000 --dump-info + +# Test +$ python3 tools/test.py configs/mmcls/classification_snpe_static.py $MODEL_CONFIG --model reset18/end2end.dlc --metrics accuracy precision f1_score recall --uri 10.0.0.1\:60000 +``` + +Note that `--uri` is required to specify the ip and port of the snpe inference service, ipv4 and ipv6 addresses can be used. + +## 4. Build SDK with Android SDK + +If you also need to compile mmdeploy SDK with Android NDK, please continue reading. + +### 1) Download NDK and OpenCV package and setup environment + +```bash +# Download android OCV +$ export OPENCV_VERSION=4.5.4 +$ wget https://github.com/opencv/opencv/releases/download/${OPENCV_VERSION}/opencv-${OPENCV_VERSION}-android-sdk.zip +$ unzip opencv-${OPENCV_VERSION}-android-sdk.zip + +$ export ANDROID_OCV_ROOT=`realpath opencv-${OPENCV_VERSION}-android-sdk` + +# Download ndk r23b +$ wget https://dl.google.com/android/repository/android-ndk-r23b-linux.zip +$ unzip android-ndk-r23b-linux.zip + +$ export ANDROID_NDK_ROOT=`realpath android-ndk-r23b` +``` + +### 2) Compile mmdeploy SDK + +```bash +$ cd /path/to/mmdeploy +$ mkdir build && cd build +$ cmake .. \ + -DMMDEPLOY_BUILD_SDK=ON -DMMDEPLOY_CODEBASES=all \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake \ + -DMMDEPLOY_CODEBASES=all -DMMDEPLOY_TARGET_BACKENDS=snpe \ + -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-30 \ + -DANDROID_STL=c++_static \ + -DOpenCV_DIR=${ANDROID_OCV_ROOT}/sdk/native/jni/abi-arm64-v8a \ + -DMMDEPLOY_SHARED_LIBS=ON + + $ make && make install +``` + +| Options | Description | +| ----------------------------- | ------------------------------------------------------------ | +| DMMDEPLOY_CODEBASES=all | Compile all algorithms' post-process | +| CMAKE_TOOLCHAIN_FILE | Load NDK parameters, mainly used to select compiler | +| MMDEPLOY_TARGET_BACKENDS=snpe | Inference backend | +| ANDROID_STL=c++\_static | In case of NDK environment can not find suitable c++ library | +| MMDEPLOY_SHARED_LIBS=ON | snpe does not provide static library | + +### 3) Compile demo + +```bash +$ cd /path/to/install/example +$ mkdir build && cd build + +$ cmake .. \ + -DMMDEPLOY_BUILD_SDK=ON -DMMDEPLOY_CODEBASES=all \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake \ + -DMMDEPLOY_CODEBASES=all -DMMDEPLOY_TARGET_BACKENDS=snpe \ + -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-30 \ + -DANDROID_STL=c++_static \ + -DOpenCV_DIR=${ANDROID_OCV_ROOT}/sdk/native/jni/abi-arm64-v8a \ + -DMMDEPLOY_SHARED_LIBS=ON \ + -DMMDeploy_DIR=${PWD}/../../lib/cmake/MMDeploy + +$ make +$ tree -L 1 +... +├── image_restorer +├── image_segmentation +├── object_detection +├── ocr +├── pose_detection +└── rotated_object_detection +``` + +Just `adb push` the binary file and .so to the device and execute. + +### 4) Run the demo + +First make sure that`--dump-info`is used during convert model, so that the `resnet18` directory has the files required by the SDK such as `pipeline.json`. + +`adb push` the model directory, executable file and .so to the device. 
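Before pushing, it can help to verify on the host that the dump step actually produced what the SDK needs. A small sketch, assuming the usual `--dump-info` layout (`deploy.json`, `pipeline.json`) plus the converted `end2end.dlc`:

```python
# Host-side sanity check (sketch) before `adb push`-ing the model directory.
# deploy.json / pipeline.json names follow the common --dump-info layout and
# are assumptions here; end2end.dlc is the converted SNPE model.
import json
import os.path as osp

work_dir = 'resnet18'
for name in ('deploy.json', 'pipeline.json', 'end2end.dlc'):
    assert osp.exists(osp.join(work_dir, name)), \
        f'{name} missing, re-run deploy.py with --dump-info'

with open(osp.join(work_dir, 'pipeline.json')) as f:
    print(list(json.load(f).keys()))  # quick look at the pipeline description
```

The actual push commands follow.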
+ +```bash +$ cd /path/to/mmdeploy +$ adb push resnet18 /data/local/tmp +$ adb push tests/data/tiger.jpeg /data/local/tmp/resnet18/ + +$ cd /path/to/install/ +$ adb push lib /data/local/tmp + +$ cd /path/to/install/example/build +$ adb push image_classification /data/local/tmp/resnet18/ +``` + +Set up environment variable and execute the sample. + +```bash +$ adb push /path/to/mmcls/demo/demo.JPEG /data/local/tmp +$ adb shell +venus:/ $ cd /data/local/tmp/resnet18 +venus:/data/local/tmp/resnet18 $ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib + +venus:/data/local/tmp/resnet18 $ ./image_classification cpu ./ tiger.jpeg +.. +label: 3, score: 0.3214 +``` diff --git a/docs/en/03-benchmark/benchmark_edge.md b/docs/en/03-benchmark/benchmark_edge.md new file mode 100644 index 0000000000..5f9ec0782c --- /dev/null +++ b/docs/en/03-benchmark/benchmark_edge.md @@ -0,0 +1,57 @@ +# Test on embedded device + +Here are the test conclusions of our edge devices. You can directly obtain the results of your own environment with [model profiling](../02-how-to-run/how_to_evaluate_a_model.md). + +## Software and hardware environment + +- host OS ubuntu 18.04 +- backend SNPE-1.59 +- device Mi11 (qcom 888) + +## mmcls + +| model | dataset | spatial | fp32 top-1 (%) | snpe gpu hybrid fp32 top-1 (%) | latency (ms) | +| :------------------------------------------------------------------------------------------------------------------------------: | :---------: | :-----: | :------------: | :----------------------------: | :----------: | +| [ShuffleNetV2](https://github.com/open-mmlab/mmclassification/blob/master/configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py) | ImageNet-1k | 224x224 | 69.55 | 69.83\* | 20±7 | +| [MobilenetV2](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py) | ImageNet-1k | 224x224 | 71.86 | 72.14\* | 15±6 | + +tips: + +1. The ImageNet-1k dataset is too large to test, only part of the dataset is used (8000/50000) +2. The heating of device will downgrade the frequency, so the time consumption will actually fluctuate. Here are the stable values after running for a period of time. This result is closer to the actual demand. + +## mmocr detection + +| model | dataset | spatial | fp32 hmean | snpe gpu hybrid hmean | latency(ms) | +| :---------------------------------------------------------------------------------------------------------------: | :-------: | :------: | :--------: | :-------------------: | :---------: | +| [PANet](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | ICDAR2015 | 1312x736 | 0.795 | 0.785 @thr=0.9 | 3100±100 | + +## mmpose + +| model | dataset | spatial | snpe hybrid AR@IoU=0.50 | snpe hybrid AP@IoU=0.50 | latency(ms) | +| :---------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------: | :-----: | :---------------------: | :---------------------: | :---------: | +| [pose_hrnet_w32](https://github.com/open-mmlab/mmpose/blob/master/configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/animalpose/hrnet_w32_animalpose_256x256.py) | Animalpose | 256x256 | 0.997 | 0.989 | 630±50 | + +tips: + +- Test `pose_hrnet` using AnimalPose's test dataset instead of val dataset. 
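The latency columns above are reported as a stable value plus its fluctuation, measured after the device has warmed up and throttled to a steady state. A rough sketch of that measurement style (warm-up runs first, then statistics over repeated runs; `run_once` is a placeholder for whatever inference call you are timing):

```python
# Sketch of the mean +/- deviation timing style used in these tables.
# `run_once` is a placeholder for one end-to-end inference call.
import statistics
import time


def profile(run_once, warmup=20, repeats=100):
    for _ in range(warmup):  # let clocks/thermals settle first
        run_once()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - start) * 1000)  # ms
    return statistics.mean(samples), statistics.stdev(samples)
```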
+ +## mmseg + +| model | dataset | spatial | mIoU | latency(ms) | +| :---------------------------------------------------------------------------------------------------------------: | :--------: | :------: | :---: | :---------: | +| [fcn](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fcn/fcn_r18-d8_512x1024_80k_cityscapes.py) | Cityscapes | 512x1024 | 71.11 | 4915±500 | + +tips: + +- `fcn` works fine with 512x1024 size. Cityscapes dataset uses 1024x2048 resolution which causes device to reboot. + +## Notes + +- We needs to manually split the mmdet model into two parts. Because + - In snpe source code, `onnx_to_ir.py` can only parse onnx input while `ir_to_dlc.py` does not support `topk` operator + - UDO (User Defined Operator) does not work with `snpe-onnx-to-dlc` +- mmedit model + - `srcnn` requires cubic resize which snpe does not support + - `esrgan` converts fine, but loading the model causes the device to reboot +- mmrotate depends on [e2cnn](https://pypi.org/project/e2cnn/) and needs to be installed manually [its Python3.6 compatible branch](https://github.com/QUVA-Lab/e2cnn) diff --git a/docs/en/appendix/cross_build_snpe_service.md b/docs/en/appendix/cross_build_snpe_service.md new file mode 100644 index 0000000000..f5aba17d87 --- /dev/null +++ b/docs/en/appendix/cross_build_snpe_service.md @@ -0,0 +1,166 @@ +# Cross compile snpe inference server on Ubuntu 18 + +mmdeploy has provided a prebuilt package, if you want to compile it by self, or need to modify the `.proto` file, you can refer to this document. + +Note that the official gRPC documentation does not have complete support for the NDK. + +## 1. Environment + +| Item | Version | Remarks | +| ------------------ | -------------- | --------------------------------------------------------- | +| snpe | 1.59 | 1.60 uses clang-8.0, which may cause compatibility issues | +| host OS | ubuntu18.04 | snpe1.59 specified version | +| NDK | r17c | snpe1.59 specified version | +| gRPC | commit 6f698b5 | - | +| Hardware equipment | qcom888 | qcom chip required | + +## 2. Cross compile gRPC with NDK + +1. Pull gRPC repo, compile `protoc` and `grpc_cpp_plugin` on host + +```bash +# Install dependencies +$ apt-get update && apt-get install -y libssl-dev +# Compile +$ git clone https://github.com/grpc/grpc --recursive=1 --depth=1 +$ mkdir -p cmake/build +$ pushd cmake/build + +$ cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_SSL_PROVIDER=package \ + ../.. +# Install to host +$ make -j +$ sudo make install +``` + +2. Download the NDK and cross-compile the static libraries with android aarch64 format + +```bash +$ wget https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip +$ unzip android-ndk-r17c-linux-x86_64.zip + +$ export ANDROID_NDK=/path/to/android-ndk-r17c + +$ cd /path/to/grpc +$ mkdir -p cmake/build_aarch64 && pushd cmake/build_aarch64 + +$ cmake ../.. \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-26 \ + -DANDROID_TOOLCHAIN=clang \ + -DANDROID_STL=c++_shared \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/tmp/android_grpc_install_shared + +$ make -j +$ make install +``` + +3. At this point `/tmp/android_grpc_install` should have the complete installation file + +```bash +$ cd /tmp/android_grpc_install +$ tree -L 1 +. +├── bin +├── include +├── lib +└── share +``` + +## 3. 
\[Skipable\] Self-test whether NDK gRPC is available + +1. Compile the helloworld that comes with gRPC + +```bash +$ cd /path/to/grpc/examples/cpp/helloworld/ +$ mkdir cmake/build_aarch64 -p && pushd cmake/build_aarch64 + +$ cmake ../.. \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-26 \ + -DANDROID_STL=c++_shared \ + -DANDROID_TOOLCHAIN=clang \ + -DCMAKE_BUILD_TYPE=Release \ + -Dabsl_DIR=/tmp/android_grpc_install_shared/lib/cmake/absl \ + -DProtobuf_DIR=/tmp/android_grpc_install_shared/lib/cmake/protobuf \ + -DgRPC_DIR=/tmp/android_grpc_install_shared/lib/cmake/grpc + +$ make -j +$ ls greeter* +greeter_async_client greeter_async_server greeter_callback_server greeter_server +greeter_async_client2 greeter_callback_client greeter_client +``` + +2. Turn on debug mode on your phone, push the binary to `/data/local/tmp` + +```bash +$ adb push greeter* /data/local/tmp +``` + +3. `adb shell` into the phone, execute client/server + +```bash +/data/local/tmp $ ./greeter_client +Greeter received: Hello world +``` + +## 4. Cross compile snpe inference server + +1. Open the [snpe tools website](https://developer.qualcomm.com/software/qualcomm-neural-processing-sdk/tools) and download version 1.59. Unzip and set environment variables + +> Note that snpe >= 1.60 starts using `clang-8.0`, which may cause incompatibility with `libc++_shared.so` on older devices. + +```bash +$ export SNPE_ROOT=/path/to/snpe-1.59.0.3230 +``` + +2. Open the snpe server directory within mmdeploy, use the options when cross-compiling gRPC + +```bash +$ cd /path/to/mmdeploy +$ cd service/snpe/server + +$ mkdir -p build && cd build +$ export ANDROID_NDK=/path/to/android-ndk-r17c +$ cmake .. \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-26 \ + -DANDROID_STL=c++_shared \ + -DANDROID_TOOLCHAIN=clang \ + -DCMAKE_BUILD_TYPE=Release \ + -Dabsl_DIR=/tmp/android_grpc_install_shared/lib/cmake/absl \ + -DProtobuf_DIR=/tmp/android_grpc_install_shared/lib/cmake/protobuf \ + -DgRPC_DIR=/tmp/android_grpc_install_shared/lib/cmake/grpc + + $ make -j + $ file inference_server +inference_server: ELF 64-bit LSB shared object, ARM aarch64, version 1 (SYSV), dynamically linked, interpreter /system/bin/linker64, BuildID[sha1]=252aa04e2b982681603dacb74b571be2851176d2, with debug_info, not stripped +``` + +Finally, you can see `infernece_server`, `adb push` it to the device and execute. + +## 5. 
Regenerate the proto interface + +If you have changed `inference.proto`, you need to regenerate the .cpp and .py interfaces + +```Shell +$ python3 -m pip install grpc_tools --user +$ python3 -m grpc_tools.protoc -I./ --python_out=./client/ --grpc_python_out=./client/ inference.proto + +$ ln -s `which protoc-gen-grpc` +$ protoc --cpp_out=./ --grpc_out=./ --plugin=protoc-gen-grpc=grpc_cpp_plugin inference.proto +``` + +## Reference + +- snpe tutorial https://developer.qualcomm.com/sites/default/files/docs/snpe/cplus_plus_tutorial.html +- gRPC cross build script https://raw.githubusercontent.com/grpc/grpc/master/test/distrib/cpp/run_distrib_test_cmake_aarch64_cross.sh +- stackoverflow https://stackoverflow.com/questions/54052229/build-grpc-c-for-android-using-ndk-arm-linux-androideabi-clang-compiler diff --git a/docs/en/index.rst b/docs/en/index.rst index 015b7163da..0bc36a6425 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -31,6 +31,7 @@ You can switch between Chinese and English documents in the lower-left corner of 03-benchmark/supported_models.md 03-benchmark/benchmark.md + 03-benchmark/benchmark_edge.md .. toctree:: :maxdepth: 1 @@ -78,6 +79,11 @@ You can switch between Chinese and English documents in the lower-left corner of :maxdepth: 1 :caption: Tutorials on Model Deployment +.. toctree:: + :maxdepth: 1 + :caption: Appendix + + appendix/cross_build_snpe_service.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/01-how-to-build/build_from_source.md b/docs/zh_cn/01-how-to-build/build_from_source.md index 5e18c9959c..66b6907d03 100644 --- a/docs/zh_cn/01-how-to-build/build_from_source.md +++ b/docs/zh_cn/01-how-to-build/build_from_source.md @@ -30,8 +30,6 @@ git clone -b master git@github.com:open-mmlab/mmdeploy.git --recursive git clone -b master https://github.com/open-mmlab/mmdeploy.git MMDeploy cd MMDeploy git submodule update --init --recursive - - ``` ## 编译 @@ -42,3 +40,4 @@ git clone -b master git@github.com:open-mmlab/mmdeploy.git --recursive - [Windows](windows.md) - [Android-aarch64](android.md) - [NVIDIA Jetson](jetsons.md) +- [Qcom SNPE](snpe.md) diff --git a/docs/zh_cn/01-how-to-build/snpe.md b/docs/zh_cn/01-how-to-build/snpe.md new file mode 100644 index 0000000000..8664c97be1 --- /dev/null +++ b/docs/zh_cn/01-how-to-build/snpe.md @@ -0,0 +1,198 @@ +# 支持 SNPE + +mmdeploy 集成 snpe 的方式简单且有效: Client/Server 模式。 + +这种模式 + +1. 能剥离`模型转换`和`推理`环境: + +- 推理无关事项在算力更高的设备上完成; +- 对于推理计算,能拿到 gpu/npu 真实运行结果,而非 cpu 模拟数值。 + +2. 能覆盖成本敏感的设备。armv7/risc-v/mips 芯片满足产品需求,但往往对 Python 支持有限; + +3. 能简化 mmdeploy 安装步骤。如果只想转 snpe 模型测试精度,不需要编译 .whl 包。 + +## 一、运行推理服务 + +下载预编译 snpe 推理服务包, `adb push` 到手机、执行。 +注意**手机要有 qcom 芯片**。 + +```bash +$ wget https://media.githubusercontent.com/media/tpoisonooo/mmdeploy_snpe_testdata/main/snpe-inference-server-1.59.tar.gz +... +$ sudo apt install adb +$ adb push snpe-inference-server-1.59.tar.gz /data/local/tmp/ + +# 解压运行 +$ adb shell +venus:/ $ cd /data/local/tmp +130|venus:/data/local/tmp $ tar xvf snpe-inference-server-1.59.tar.gz +... +130|venus:/data/local/tmp $ source export1.59.sh +130|venus:/data/local/tmp $ ./inference_server 60000 +... + Server listening on [::]:60000 +``` + +此时推理服务应打印设备所有 ipv6 和 ipv4 地址,并监听端口。 + +tips: + +- 如果 `adb devices` 找不到设备,可能因为: + - 有些廉价线只能充电、不能传输数据 + - 或者没有打开手机的“开发者模式” +- 如果需要自己编译,可参照 [NDK 交叉编译 snpe 推理服务](../appendix/cross_build_snpe_service.md) +- 如果监听端口时 `segmentation fault`,可能是因为: + - 端口号已占用,换一个端口 + +## 二、安装 mmdeploy + +1. 
环境要求 + +| 事项 | 版本 | 备注 | +| ------- | ------------------ | ------------- | +| host OS | ubuntu18.04 x86_64 | snpe 指定版本 | +| Python | **3.6.0** | snpe 指定版本 | + +2. 安装 + +[官网下载 snpe-1.59](https://developer.qualcomm.com/qfile/69652/snpe-1.59.0.zip),解压设置环境变量 + +```bash +$ unzip snpe-1.59.0.zip +$ export SNPE_ROOT=${PWD}/snpe-1.59.0.3230 +$ cd /path/to/mmdeploy +$ export PYTHONPATH=${PWD}/service/snpe/client:${SNPE_ROOT}/lib/python:${PYTHONPATH} +$ export LD_LIBRARY_PATH=${SNPE_ROOT}/lib/x86_64-linux-clang:${LD_LIBRARY_PATH} +$ export PATH=${SNPE_ROOT}/bin/x86_64-linux-clang:${PATH} +$ python3 -m pip install -e . +``` + +tips: + +- 如果网络不好,[这个 .tar.gz](https://github.com/tpoisonooo/mmdeploy_snpe_testdata/blob/main/snpe-1.59.tar.gz) 仅减小官方包体积,没有修改原始内容。 + +## 三、测试模型 + +以 Resnet-18 为例。先参照[文档安装 mmcls](https://github.com/open-mmlab/mmclassification),然后使用 `tools/deploy.py` 转换模型。 + +```bash +$ export MODEL_CONFIG=/path/to/mmclassification/configs/resnet/resnet18_8xb16_cifar10.py +$ export MODEL_PATH=https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth + +# 模型转换 +$ cd /path/to/mmdeploy +$ python3 tools/deploy.py configs/mmcls/classification_snpe_static.py $MODEL_CONFIG $MODEL_PATH /path/to/test.png --work-dir resnet18 --device cpu --uri 192.168.1.1\:60000 --dump-info + +# 精度测试 +$ python3 tools/test.py configs/mmcls/classification_snpe_static.py $MODEL_CONFIG --model reset18/end2end.dlc --metrics accuracy precision f1_score recall --uri 192.168.1.1\:60000 +``` + +注意需要 `--uri` 指明 snpe 推理服务的 ip 和端口号,可以使用 ipv4 和 ipv6 地址。 + +## 四、Android NDK 编译 SDK + +如果你还需要用 Android NDK 编译 mmdeploy SDK,请继续阅读本章节。 + +### 1. 下载 OCV、NDK,设置环境变量 + +```bash +# 下载 android OCV +$ export OPENCV_VERSION=4.5.4 +$ wget https://github.com/opencv/opencv/releases/download/${OPENCV_VERSION}/opencv-${OPENCV_VERSION}-android-sdk.zip +$ unzip opencv-${OPENCV_VERSION}-android-sdk.zip + +$ export ANDROID_OCV_ROOT=`realpath opencv-${OPENCV_VERSION}-android-sdk` + +# 下载 ndk r23b +$ wget https://dl.google.com/android/repository/android-ndk-r23b-linux.zip +$ unzip android-ndk-r23b-linux.zip + +$ export ANDROID_NDK_ROOT=`realpath android-ndk-r23b` +``` + +### 2. 编译 mmdeploy SDK + +```bash +$ cd /path/to/mmdeploy +$ mkdir build && cd build +$ cmake .. \ + -DMMDEPLOY_BUILD_SDK=ON -DMMDEPLOY_CODEBASES=all \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake \ + -DMMDEPLOY_CODEBASES=all -DMMDEPLOY_TARGET_BACKENDS=snpe \ + -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-30 \ + -DANDROID_STL=c++_static \ + -DOpenCV_DIR=${ANDROID_OCV_ROOT}/sdk/native/jni/abi-arm64-v8a \ + -DMMDEPLOY_SHARED_LIBS=ON + + $ make && make install +``` + +选项说明 + +| 选项 | 说明 | +| ----------------------------- | ------------------------------------- | +| DMMDEPLOY_CODEBASES=all | 编译所有算法后处理 | +| CMAKE_TOOLCHAIN_FILE | 加载 NDK 参数,主要用于选择编译器版本 | +| MMDEPLOY_TARGET_BACKENDS=snpe | 使用 snpe 推理 | +| ANDROID_STL=c++\_static | 避免 NDK 环境找不到合适的 c++ lib | +| MMDEPLOY_SHARED_LIBS=ON | 官方 snpe 没有提供静态库 | + +### 3. 编译 demo + +```bash +$ cd /path/to/install/example +$ mkdir build && cd build + +$ cmake .. 
\ + -DMMDEPLOY_CODEBASES=all \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake \ + -DMMDEPLOY_CODEBASES=all -DMMDEPLOY_TARGET_BACKENDS=snpe \ + -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-30 \ + -DANDROID_STL=c++_static \ + -DOpenCV_DIR=${ANDROID_OCV_ROOT}/sdk/native/jni/abi-arm64-v8a \ + -DMMDeploy_DIR=${PWD}/../../lib/cmake/MMDeploy + +$ make +$ tree -L 1 +. +├── image_classification +├── image_restorer +├── image_segmentation +├── object_detection +├── ocr +├── pose_detection +└── rotated_object_detection +``` + +## 4. 运行 demo + +先确认测试模型用了 `--dump-info`,这样 `resnet18` 目录才有 `pipeline.json` 等 SDK 所需文件。 + +把 dump 好的模型目录、可执行文件和 lib 都 `adb push` 到设备里 + +```bash +$ cd /path/to/mmdeploy +$ adb push resnet18 /data/local/tmp +$ adb push tests/data/tiger.jpeg /data/local/tmp/resnet18/ + +$ cd /path/to/install/ +$ adb push lib /data/local/tmp + +$ cd /path/to/install/example/build +$ adb push image_classification /data/local/tmp/resnet18/ +``` + +设置环境变量,执行样例 + +```bash +$ adb push /path/to/mmcls/demo/demo.JPEG /data/local/tmp +$ adb shell +venus:/ $ cd /data/local/tmp/resnet18 +venus:/data/local/tmp/resnet18 $ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/data/local/tmp/lib + +venus:/data/local/tmp/resnet18 $ ./image_classification cpu ./ tiger.jpeg +.. +label: 3, score: 0.3214 +``` diff --git a/docs/zh_cn/03-benchmark/benchmark_edge.md b/docs/zh_cn/03-benchmark/benchmark_edge.md new file mode 100644 index 0000000000..d320c3fb7b --- /dev/null +++ b/docs/zh_cn/03-benchmark/benchmark_edge.md @@ -0,0 +1,58 @@ +# 边、端设备测试结果 + +这里给出我们边、端设备的测试结论,用户可以直接通过 [model profiling](../02-how-to-run/profile_model.md) 获得自己环境的结果。 + +## 软硬件环境 + +- host OS ubuntu 18.04 +- backend SNPE-1.59 +- device Mi11 (qcom 888) + +## mmcls 模型 + +| model | dataset | spatial | fp32 top-1 (%) | snpe gpu hybrid fp32 top-1 (%) | latency (ms) | +| :------------------------------------------------------------------------------------------------------------------------------: | :---------: | :-----: | :------------: | :----------------------------: | :----------: | +| [ShuffleNetV2](https://github.com/open-mmlab/mmclassification/blob/master/configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py) | ImageNet-1k | 224x224 | 69.55 | 69.83\* | 20±7 | +| [MobilenetV2](https://github.com/open-mmlab/mmclassification/blob/master/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py) | ImageNet-1k | 224x224 | 71.86 | 72.14\* | 15±6 | + +tips: + +1. ImageNet-1k 数据集较大,仅使用一部分测试(8000/50000) +2. 
边、端设备发热会降频,因此耗时实际上会波动。这里给出运行一段时间后、稳定的数值。这个结果更贴近实际需求 + +## mmocr 检测 + +| model | dataset | spatial | fp32 hmean | snpe gpu hybrid hmean | latency(ms) | +| :---------------------------------------------------------------------------------------------------------------: | :-------: | :------: | :--------: | :-------------------: | :---------: | +| [PANet](https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | ICDAR2015 | 1312x736 | 0.795 | 0.785 @thr=0.9 | 3100±100 | + +## mmpose 模型 + +| model | dataset | spatial | snpe hybrid AR@IoU=0.50 | snpe hybrid AP@IoU=0.50 | latency(ms) | +| :---------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------: | :-----: | :---------------------: | :---------------------: | :---------: | +| [pose_hrnet_w32](https://github.com/open-mmlab/mmpose/blob/master/configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/animalpose/hrnet_w32_animalpose_256x256.py) | Animalpose | 256x256 | 0.997 | 0.989 | 630±50 | + +tips: + +- 测试 pose_hrnet 用的是 AnimalPose 的 test dataset,而非 val dataset + +## mmseg + +| model | dataset | spatial | mIoU | latency(ms) | +| :---------------------------------------------------------------------------------------------------------------: | :--------: | :------: | :---: | :---------: | +| [fcn](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fcn/fcn_r18-d8_512x1024_80k_cityscapes.py) | Cityscapes | 512x1024 | 71.11 | 4915±500 | + +tips: + +- fcn 用 512x1024 尺寸运行正常。Cityscapes 数据集 1024x2048 分辨率会导致设备重启 + +## 其他模型 + +- mmdet 需要手动把模型拆成两部分。因为 + - snpe 源码中 `onnx_to_ir.py` 仅能解析输入,`ir_to_dlc.py` 还不支持 topk + - UDO (用户自定义算子)无法和 `snpe-onnx-to-dlc` 配合使用 +- mmedit 模型 + - srcnn 需要 cubic resize,snpe 不支持 + - esrgan 可正常转换,但加载模型会导致设备重启 +- mmrotate 依赖 [e2cnn](https://pypi.org/project/e2cnn/) ,需要手动安装 [其 Python3.6 + 兼容分支](https://github.com/QUVA-Lab/e2cnn) diff --git a/docs/zh_cn/04-developer-guide/do_regression_test.md b/docs/zh_cn/04-developer-guide/do_regression_test.md index 45cd5d32d8..d1b33d0076 100644 --- a/docs/zh_cn/04-developer-guide/do_regression_test.md +++ b/docs/zh_cn/04-developer-guide/do_regression_test.md @@ -257,6 +257,7 @@ models: - [x] ncnn - [x] OpenVINO - [x] TorchScript +- [x] SNPE - [x] MMDeploy SDK ## 6. 支持的Codebase及其Metric diff --git a/docs/zh_cn/appendix/cross_build_snpe_service.md b/docs/zh_cn/appendix/cross_build_snpe_service.md new file mode 100644 index 0000000000..bb1ea4d40c --- /dev/null +++ b/docs/zh_cn/appendix/cross_build_snpe_service.md @@ -0,0 +1,170 @@ +# Ubuntu18.04 交叉编译 NDK snpe 推理服务 + +mmdeploy 已提供预编译包,如果你想自己编译、或需要对 .proto 接口做修改,可参考此文档。 + +注意 gRPC 官方文档并没有对 NDK 的完整支持。 + +## 一、环境说明 + +| 项目 | 版本 | 备注 | +| -------- | -------------- | ------------------------------------- | +| snpe | 1.59 | 1.60 使用 clang-8.0,可能导致兼容问题 | +| host OS | ubuntu18.04 | snpe1.59 指定版本 | +| NDK | r17c | snpe1.59 指定版本 | +| gRPC | commit 6f698b5 | - | +| 硬件设备 | qcom888 | 需要 qcom 芯片 | + +## 二、NDK 交叉编译 gRPC + +1. 拉取 gRPC repo, 在 host 上编译出 `protoc` 和 `grpc_cpp_plugin` + +```bash +# 安装依赖 +$ apt-get update && apt-get install -y libssl-dev +# 编译 +$ git clone https://github.com/grpc/grpc --recursive=1 --depth=1 +$ mkdir -p cmake/build +$ pushd cmake/build + +$ cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_SSL_PROVIDER=package \ + ../.. +# 需要安装到 host 环境 +$ make -j +$ sudo make install +``` + +2. 
下载 NDK,交叉编译 android aarch64 所需静态库 + +```bash +$ wget https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip +$ unzip android-ndk-r17c-linux-x86_64.zip + +# 设置环境变量 +$ export ANDROID_NDK=/path/to/android-ndk-r17c + +# 编译 +$ cd /path/to/grpc +$ mkdir -p cmake/build_aarch64 && pushd cmake/build_aarch64 + +$ cmake ../.. \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-26 \ + -DANDROID_TOOLCHAIN=clang \ + -DANDROID_STL=c++_shared \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/tmp/android_grpc_install_shared + +$ make -j +$ make install +``` + +3. 此时 `/tmp/android_grpc_install` 应有完整的安装文件 + +```bash +$ cd /tmp/android_grpc_install +$ tree -L 1 +. +├── bin +├── include +├── lib +└── share +``` + +## 三、【可跳过】自测 NDK gRPC 是否正常 + +1. 编译 gRPC 自带的 helloworld + +```bash +$ cd /path/to/grpc/examples/cpp/helloworld/ +$ mkdir cmake/build_aarch64 -p && pushd cmake/build_aarch64 + +$ cmake ../.. \ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-26 \ + -DANDROID_STL=c++_shared \ + -DANDROID_TOOLCHAIN=clang \ + -DCMAKE_BUILD_TYPE=Release \ + -Dabsl_DIR=/tmp/android_grpc_install_shared/lib/cmake/absl \ + -DProtobuf_DIR=/tmp/android_grpc_install_shared/lib/cmake/protobuf \ + -DgRPC_DIR=/tmp/android_grpc_install_shared/lib/cmake/grpc + +$ make -j +$ ls greeter* +greeter_async_client greeter_async_server greeter_callback_server greeter_server +greeter_async_client2 greeter_callback_client greeter_client +``` + +2. 打开手机调试模式,push 编译结果到 `/data/local/tmp` 目录 + +tips:对于国产手机,设置 - 版本号,点击 7 次可进入开发者模式,然后才能打开 USB 调试 + +```bash +$ adb push greeter* /data/local/tmp +``` + +3. `adb shell` 进手机,执行 client/server + +```bash +/data/local/tmp $ ./greeter_client +Greeter received: Hello world +``` + +## 四、交叉编译 snpe 推理服务 + +1. 打开 [snpe tools 官网](https://developer.qualcomm.com/software/qualcomm-neural-processing-sdk/tools),下载 1.59 版本。 解压并设置环境变量 + +**注意 snpe >= 1.60 开始使用 `clang-8.0`,可能导致旧设备与 `libc++_shared.so` 不兼容。** + +```bash +$ export SNPE_ROOT=/path/to/snpe-1.59.0.3230 +``` + +2. 打开 mmdeploy snpe server 目录,使用交叉编译 gRPC 时的选项 + +```bash +$ cd /path/to/mmdeploy +$ cd service/snpe/server + +$ mkdir -p build && cd build +$ export ANDROID_NDK=/path/to/android-ndk-r17c +$ cmake .. 
\ + -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-26 \ + -DANDROID_STL=c++_shared \ + -DANDROID_TOOLCHAIN=clang \ + -DCMAKE_BUILD_TYPE=Release \ + -Dabsl_DIR=/tmp/android_grpc_install_shared/lib/cmake/absl \ + -DProtobuf_DIR=/tmp/android_grpc_install_shared/lib/cmake/protobuf \ + -DgRPC_DIR=/tmp/android_grpc_install_shared/lib/cmake/grpc + + $ make -j + $ file inference_server +inference_server: ELF 64-bit LSB shared object, ARM aarch64, version 1 (SYSV), dynamically linked, interpreter /system/bin/linker64, BuildID[sha1]=252aa04e2b982681603dacb74b571be2851176d2, with debug_info, not stripped +``` + +最终可得到 `infernece_server`,`adb push` 到设备上即可执行。 + +## 五、重新生成 proto 接口 + +如果改过 `inference.proto`,需要重新生成 .cpp 和 .py 通信接口 + +```Shell +$ python3 -m pip install grpc_tools --user +$ python3 -m grpc_tools.protoc -I./ --python_out=./client/ --grpc_python_out=./client/ inference.proto + +$ ln -s `which protoc-gen-grpc` +$ protoc --cpp_out=./ --grpc_out=./ --plugin=protoc-gen-grpc=grpc_cpp_plugin inference.proto +``` + +## 参考文档 + +- snpe tutorial https://developer.qualcomm.com/sites/default/files/docs/snpe/cplus_plus_tutorial.html +- gRPC cross build script https://raw.githubusercontent.com/grpc/grpc/master/test/distrib/cpp/run_distrib_test_cmake_aarch64_cross.sh +- stackoverflow https://stackoverflow.com/questions/54052229/build-grpc-c-for-android-using-ndk-arm-linux-androideabi-clang-compiler diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 0393d78f9c..b53cbf68e6 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -52,6 +52,12 @@ 05-tutorial/04_onnx_custom_op.md 05-tutorial/05_onnx_model_editing.md +.. toctree:: + :maxdepth: 1 + :caption: 附录 + + appendix/cross_build_snpe_service.md + .. toctree:: :maxdepth: 1 :caption: 常见问题 diff --git a/mmdeploy/apis/snpe/__init__.py b/mmdeploy/apis/snpe/__init__.py new file mode 100644 index 0000000000..6f8febaec3 --- /dev/null +++ b/mmdeploy/apis/snpe/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.backend.snpe import from_onnx as _from_onnx +from mmdeploy.backend.snpe import is_available +from ..core import PIPELINE_MANAGER + +from_onnx = PIPELINE_MANAGER.register_pipeline()(_from_onnx) + +__all__ = ['is_available', 'from_onnx'] + +if is_available(): + try: + from mmdeploy.backend.snpe.onnx2dlc import (get_env_key, + get_output_model_file) + __all__ += ['get_output_model_file', 'get_env_key'] + except Exception: + pass diff --git a/mmdeploy/apis/visualize.py b/mmdeploy/apis/visualize.py index 251880ed3e..be03944c1c 100644 --- a/mmdeploy/apis/visualize.py +++ b/mmdeploy/apis/visualize.py @@ -16,7 +16,8 @@ def visualize_model(model_cfg: Union[str, mmcv.Config], device: str, backend: Optional[Backend] = None, output_file: Optional[str] = None, - show_result: bool = False): + show_result: bool = False, + **kwargs): """Run inference with PyTorch or backend model and show results. 
Examples: @@ -64,7 +65,7 @@ def visualize_model(model_cfg: Union[str, mmcv.Config], if backend == Backend.PYTORCH: model = task_processor.init_pytorch_model(model[0]) else: - model = task_processor.init_backend_model(model) + model = task_processor.init_backend_model(model, **kwargs) model_inputs, _ = task_processor.create_input(img, input_shape) with torch.no_grad(): diff --git a/mmdeploy/backend/sdk/export_info.py b/mmdeploy/backend/sdk/export_info.py index 459e82b18b..8e68f413f2 100644 --- a/mmdeploy/backend/sdk/export_info.py +++ b/mmdeploy/backend/sdk/export_info.py @@ -127,6 +127,8 @@ def replace_suffix(file_name: str, dst_suffix: str) -> str: weights = replace_suffix(ir_name, '.bin') if 'precision' in deploy_cfg['backend_config']: precision = deploy_cfg['backend_config']['precision'] + elif backend == Backend.SNPE: + net = replace_suffix(ir_name, '.dlc') elif backend in [Backend.ONNXRUNTIME, Backend.TORCHSCRIPT]: pass else: diff --git a/mmdeploy/backend/snpe/__init__.py b/mmdeploy/backend/snpe/__init__.py new file mode 100644 index 0000000000..961b75dc7e --- /dev/null +++ b/mmdeploy/backend/snpe/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from .init_plugins import get_onnx2dlc_path +from .onnx2dlc import from_onnx + + +def is_available(): + """Check whether ncnn and snpe-onnx-to-dlc tool are installed. + + Returns: + bool: True if snpe-onnx-to-dlc tool are installed. + """ + + onnx2dlc = get_onnx2dlc_path() + if onnx2dlc is None: + return False + return osp.exists(onnx2dlc) + + +__all__ = ['from_onnx'] + +if is_available(): + try: + from .wrapper import SNPEWrapper + + __all__ += ['SNPEWrapper'] + except Exception as e: + print(e) + pass diff --git a/mmdeploy/backend/snpe/init_plugins.py b/mmdeploy/backend/snpe/init_plugins.py new file mode 100644 index 0000000000..7f4c35394d --- /dev/null +++ b/mmdeploy/backend/snpe/init_plugins.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import shutil + + +def get_onnx2dlc_path() -> str: + """Get snpe-onnx-to-dlc path. + + Returns: + str: A path of snpe-onnx-to-dlc tool. + """ + return shutil.which('snpe-onnx-to-dlc') diff --git a/mmdeploy/backend/snpe/onnx2dlc.py b/mmdeploy/backend/snpe/onnx2dlc.py new file mode 100644 index 0000000000..45e727e459 --- /dev/null +++ b/mmdeploy/backend/snpe/onnx2dlc.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import tempfile +from subprocess import call +from typing import List, Optional, Union + +import onnx + +from .init_plugins import get_onnx2dlc_path + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == '': + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def get_env_key() -> str: + """Return environment key str. + + Returns: + str: The string to find SNPE service URI + """ + return '__MMDEPLOY_SNPE_URI' + + +def get_output_model_file(onnx_path: str, + work_dir: Optional[str] = None) -> List[str]: + """Returns the path to the .dlc file with export result. + + Args: + onnx_path (str): The path to the onnx model. + work_dir (str|None): The path to the directory for saving the results. + Defaults to `None`, which means use the directory of onnx_path. + + Returns: + List[str]: The path to the files where the export result will be + located. 
+ """ + if work_dir is None: + work_dir = osp.dirname(onnx_path) + mkdir_or_exist(osp.abspath(work_dir)) + file_name = osp.splitext(osp.split(onnx_path)[1])[0] + save_dlc = osp.join(work_dir, file_name + '.dlc') + return save_dlc + + +def from_onnx(onnx_model: Union[onnx.ModelProto, str], + output_file_prefix: str): + """Convert ONNX to dlc. + + We need to use a executable program to convert the `.onnx` file to a `.dlc` + + Example: + >>> from mmdeploy.apis.snpe import from_onnx + >>> onnx_path = 'work_dir/end2end.onnx' + >>> output_file_prefix = 'work_dir/end2end' + >>> from_onnx(onnx_path, output_file_prefix) + + Args: + onnx_path (ModelProto|str): The path of the onnx model. + output_file_prefix (str): The path to save the output .dlc file. + """ + + if not isinstance(onnx_model, str): + onnx_path = tempfile.NamedTemporaryFile(suffix='.onnx').name + onnx.save(onnx_model, onnx_path) + else: + onnx_path = onnx_model + + save_dlc = output_file_prefix + '.dlc' + + onnx2dlc = get_onnx2dlc_path() + ret_code = call( + [onnx2dlc, '--input_network', onnx_path, '--output', save_dlc]) + assert ret_code == 0, 'onnx2dlc failed' diff --git a/mmdeploy/backend/snpe/wrapper.py b/mmdeploy/backend/snpe/wrapper.py new file mode 100644 index 0000000000..f16d6a554b --- /dev/null +++ b/mmdeploy/backend/snpe/wrapper.py @@ -0,0 +1,250 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import abc +import os +import time +from random import randint +from typing import Dict, Optional, Sequence, Tuple + +import grpc +import inference_pb2 +import inference_pb2_grpc +import numpy as np +import torch + +from mmdeploy.backend.snpe.onnx2dlc import get_env_key +from mmdeploy.utils import Backend, get_root_logger +from mmdeploy.utils.timer import TimeCounter +from ..base import BACKEND_WRAPPER, BaseWrapper + + +# add interceptor to sleep and retry request +# https://github.com/grpc/grpc/issues/19514 +class SleepingPolicy(abc.ABC): + + @abc.abstractmethod + def sleep(self, try_i: int): + """How long to sleep in milliseconds. 
+ + :param try_i: the number of retry (starting from zero) + """ + assert try_i >= 0 + + +class ExponentialBackoff(SleepingPolicy): + + def __init__(self, *, init_backoff_ms: int, max_backoff_ms: int, + multiplier: int): + self.init_backoff = randint(0, init_backoff_ms) + self.max_backoff = max_backoff_ms + self.multiplier = multiplier + + def sleep(self, try_i: int): + sleep_range = min(self.init_backoff * self.multiplier**try_i, + self.max_backoff) + sleep_ms = randint(0, sleep_range) + logger = get_root_logger() + logger.debug(f'Sleeping for {sleep_ms}') + time.sleep(sleep_ms / 1000) + + +class RetryOnRpcErrorClientInterceptor(grpc.UnaryUnaryClientInterceptor, + grpc.StreamUnaryClientInterceptor): + + def __init__( + self, + *, + max_attempts: int, + sleeping_policy: SleepingPolicy, + status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, + ): + self.max_attempts = max_attempts + self.sleeping_policy = sleeping_policy + self.status_for_retry = status_for_retry + + def _intercept_call(self, continuation, client_call_details, + request_or_iterator): + + for try_i in range(self.max_attempts): + response = continuation(client_call_details, request_or_iterator) + + if isinstance(response, grpc.RpcError): + + # Return if it was last attempt + if try_i == (self.max_attempts - 1): + return response + + # If status code is not in retryable status codes + if (self.status_for_retry + and response.code() not in self.status_for_retry): + return response + + self.sleeping_policy.sleep(try_i) + else: + return response + + def intercept_unary_unary(self, continuation, client_call_details, + request): + return self._intercept_call(continuation, client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + return self._intercept_call(continuation, client_call_details, + request_iterator) + + +@BACKEND_WRAPPER.register_module(Backend.SNPE.value) +class SNPEWrapper(BaseWrapper): + """snpe wrapper class for inference. + + Args: + dlc_file (str): Path of a weight file. + output_names (Sequence[str] | None): Names of model outputs in order. + Defaults to `None` and the wrapper will load the output names from + snpe model. 
+ + Examples: + >>> from mmdeploy.backend.snpe import SNPEWrapper + >>> import torch + >>> + >>> snple_file = 'alexnet.dlc' + >>> model = SNPEWrapper(snpe_file) + >>> inputs = dict(input=torch.randn(1, 3, 224, 224)) + >>> outputs = model(inputs) + >>> print(outputs) + """ + + def __init__(self, + dlc_file: str, + uri: str, + output_names: Optional[Sequence[str]] = None, + **kwargs): + + logger = get_root_logger() + + interceptors = (RetryOnRpcErrorClientInterceptor( + max_attempts=4, + sleeping_policy=ExponentialBackoff( + init_backoff_ms=100, max_backoff_ms=1600, multiplier=2), + status_for_retry=(grpc.StatusCode.UNAVAILABLE, ), + ), ) + + if uri is None and get_env_key() in os.environ: + logger.warn( + 'snpe remote service URI not set, search from environment') + uri = os.environ[get_env_key()] + + if uri is None: + logger.error('URI not set') + + weights = bytes() + filesize = os.stat(dlc_file).st_size + + logger.info(f'reading local model file {dlc_file}') + with open(dlc_file, 'rb') as f: + weights = f.read(filesize) + + self.stub = inference_pb2_grpc.InferenceStub( + grpc.intercept_channel(grpc.insecure_channel(uri), *interceptors)) + + logger.info('init remote SNPE engine with RPC, please wait...') + model = inference_pb2.Model(name=dlc_file, weights=weights, device=1) + resp = self.stub.Init(model) + + if resp.status != 0: + logger.error(f'init SNPE model failed {resp.info}') + return + + output = self.stub.OutputNames(inference_pb2.Empty()) + output_names = output.names + + super().__init__(output_names) + logger.info(f'init success, outputs {output_names}') + + def forward(self, inputs: Dict[str, + torch.Tensor]) -> Dict[str, torch.Tensor]: + """Run forward inference. + + Args: + inputs (Dict[str, torch.Tensor]): Key-value pairs of model inputs. + + Returns: + Dict[str, torch.Tensor]: Key-value pairs of model outputs. + """ + + def get_shape(shape): + if len(shape) == 4: + return (0, 2, 3, 1) + elif len(shape) == 3: + return (0, 1, 2) + elif len(shape) == 2: + return (0, 1) + return (0) + + input_list = list(inputs.values()) + device_type = input_list[0].device.type + + logger = get_root_logger() + + # build `list` inputs for remote snpe engine + snpe_inputs = [] + for name, input_tensor in inputs.items(): + data = input_tensor.contiguous().detach() + # snpe input layout is NHWC + data = data.permute(get_shape(data.shape)) + data = data.cpu().numpy() + + if data.dtype != np.float32: + logger.error('SNPE now only support fp32 input') + data = data.astype(dtype=np.float32) + tensor = inference_pb2.Tensor( + data=data.tobytes(), + name=name, + dtype='float32', + shape=list(data.shape)) + + snpe_inputs.append(tensor) + + return self.__snpe_execute( + tensorList=inference_pb2.TensorList(data=snpe_inputs), + device=device_type) + + @TimeCounter.count_time(Backend.SNPE.value) + def __snpe_execute(self, tensorList: inference_pb2.TensorList, + device: str) -> Dict[str, torch.tensor]: + """Run inference with snpe remote inference engine. + + Args: + tensorList (inference_pb2.TensorList): snpe input tensor. + + Returns: + dict[str, torch.tensor]: Inference results of snpe model. 
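A small illustration of the layout handling in `forward()` above: SNPE consumes NHWC tensors, so an NCHW input is permuted before it is serialized into an `inference_pb2.Tensor` (the shape below is only an example):

```python
import torch

x = torch.randn(1, 3, 224, 224)     # NCHW input as produced by create_input
print(x.permute(0, 2, 3, 1).shape)  # torch.Size([1, 224, 224, 3]), i.e. NHWC
```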
+ """ + resp = self.stub.Inference(tensorList) + + def get_shape(shape): + if len(shape) == 4: + if shape[0] == 1 and shape[ + 1] == 1 and shape[2] > 1 and shape[3] > 1: + # snpe NHWC layout works except for segmentation task + return (0, 1, 2, 3) + return (0, 3, 1, 2) + elif len(shape) == 3: + return (0, 1, 2) + elif len(shape) == 2: + return (0, 1) + return (0) + + result = dict() + if resp.status == 0: + for tensor in resp.data: + ndarray = np.frombuffer(tensor.data, dtype=np.float32) + shape = tuple(tensor.shape) + data = torch.from_numpy( + ndarray.reshape(shape).copy()).to(device) + data = data.permute(get_shape(data.shape)) + result[tensor.name] = data + else: + logger = get_root_logger() + logger.error(f'snpe inference failed {resp.info}') + + return result diff --git a/mmdeploy/codebase/base/backend_model.py b/mmdeploy/codebase/base/backend_model.py index 50515e57bd..3a3ae3fafc 100644 --- a/mmdeploy/codebase/base/backend_model.py +++ b/mmdeploy/codebase/base/backend_model.py @@ -106,6 +106,13 @@ def _build_wrapper(backend: Backend, model=backend_files[0], input_names=input_names, output_names=output_names) + elif backend == Backend.SNPE: + from mmdeploy.backend.snpe import SNPEWrapper + uri = None + if 'uri' in kwargs: + uri = kwargs['uri'] + return SNPEWrapper( + dlc_file=backend_files[0], uri=uri, output_names=output_names) else: raise NotImplementedError(f'Unknown backend type: {backend.value}') diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index a51bea1ca8..d6d2c8333f 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -76,7 +76,11 @@ def init_backend_model(self, from .classification_model import build_classification_model model = build_classification_model( - model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model_files, + self.model_cfg, + self.deploy_cfg, + device=self.device, + **kwargs) return model.eval() diff --git a/mmdeploy/codebase/mmcls/deploy/classification_model.py b/mmdeploy/codebase/mmcls/deploy/classification_model.py index bf6dcbca2a..915550d06e 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification_model.py +++ b/mmdeploy/codebase/mmcls/deploy/classification_model.py @@ -41,15 +41,19 @@ def __init__( device: str, class_names: Sequence[str], deploy_cfg: Union[str, mmcv.Config] = None, + **kwargs, ): super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg) self.CLASSES = class_names self.deploy_cfg = deploy_cfg self._init_wrapper( - backend=backend, backend_files=backend_files, device=device) + backend=backend, + backend_files=backend_files, + device=device, + **kwargs) def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], - device: str): + device: str, **kwargs): output_names = self.output_names self.wrapper = BaseBackendModel._build_wrapper( backend=backend, @@ -57,7 +61,8 @@ def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], device=device, input_names=[self.input_name], output_names=output_names, - deploy_cfg=self.deploy_cfg) + deploy_cfg=self.deploy_cfg, + **kwargs) def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list: """Run forward inference. 
diff --git a/mmdeploy/codebase/mmedit/deploy/super_resolution.py b/mmdeploy/codebase/mmedit/deploy/super_resolution.py index 8d9140683d..477bb6a50b 100644 --- a/mmdeploy/codebase/mmedit/deploy/super_resolution.py +++ b/mmdeploy/codebase/mmedit/deploy/super_resolution.py @@ -89,7 +89,11 @@ def init_backend_model(self, """ from .super_resolution_model import build_super_resolution_model model = build_super_resolution_model( - model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model_files, + self.model_cfg, + self.deploy_cfg, + device=self.device, + **kwargs) return model def init_pytorch_model(self, diff --git a/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py b/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py index 454de8b951..933390ce74 100644 --- a/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py +++ b/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py @@ -40,16 +40,20 @@ def __init__(self, backend_files: Sequence[str], device: str, model_cfg: mmcv.Config, - deploy_cfg: Union[str, mmcv.Config] = None): + deploy_cfg: Union[str, mmcv.Config] = None, + **kwargs): super().__init__(deploy_cfg=deploy_cfg) self.deploy_cfg = deploy_cfg self.test_cfg = model_cfg.test_cfg self.allowed_metrics = {'PSNR': psnr, 'SSIM': ssim} self._init_wrapper( - backend=backend, backend_files=backend_files, device=device) + backend=backend, + backend_files=backend_files, + device=device, + **kwargs) def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], - device: str): + device: str, **kwargs): output_names = self.output_names self.wrapper = BaseBackendModel._build_wrapper( backend=backend, @@ -57,7 +61,8 @@ def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], device=device, input_names=[self.input_name], output_names=output_names, - deploy_cfg=self.deploy_cfg) + deploy_cfg=self.deploy_cfg, + **kwargs) def forward(self, lq: torch.Tensor, @@ -231,8 +236,8 @@ def forward(self, def build_super_resolution_model(model_files: Sequence[str], model_cfg: Union[str, mmcv.Config], - deploy_cfg: Union[str, - mmcv.Config], device: str): + deploy_cfg: Union[str, mmcv.Config], + device: str, **kwargs): model_cfg = load_config(model_cfg)[0] deploy_cfg = load_config(deploy_cfg)[0] @@ -245,6 +250,7 @@ def build_super_resolution_model(model_files: Sequence[str], backend_files=model_files, device=device, model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + deploy_cfg=deploy_cfg, + **kwargs) return backend_model diff --git a/mmdeploy/codebase/mmocr/deploy/text_detection.py b/mmdeploy/codebase/mmocr/deploy/text_detection.py index 051d4679b6..cbd8b155e9 100644 --- a/mmdeploy/codebase/mmocr/deploy/text_detection.py +++ b/mmdeploy/codebase/mmocr/deploy/text_detection.py @@ -76,7 +76,11 @@ def init_backend_model(self, """ from .text_detection_model import build_text_detection_model model = build_text_detection_model( - model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model_files, + self.model_cfg, + self.deploy_cfg, + device=self.device, + **kwargs) return model.eval() def init_pytorch_model(self, diff --git a/mmdeploy/codebase/mmocr/deploy/text_detection_model.py b/mmdeploy/codebase/mmocr/deploy/text_detection_model.py index d6917161d9..9c42b70b0a 100644 --- a/mmdeploy/codebase/mmocr/deploy/text_detection_model.py +++ b/mmdeploy/codebase/mmocr/deploy/text_detection_model.py @@ -43,6 +43,7 @@ def __init__( device: str, deploy_cfg: Union[str, mmcv.Config] = None, model_cfg: Union[str, mmcv.Config] = None, + **kwargs, ): super(End2EndModel, 
self).__init__(deploy_cfg=deploy_cfg) model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg) @@ -50,10 +51,13 @@ def __init__( self.show_score = False self.bbox_head = build_head(model_cfg.model.bbox_head) self._init_wrapper( - backend=backend, backend_files=backend_files, device=device) + backend=backend, + backend_files=backend_files, + device=device, + **kwargs) def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], - device: str): + device: str, **kwargs): """Initialize the wrapper of backends. Args: @@ -69,7 +73,8 @@ def _init_wrapper(self, backend: Backend, backend_files: Sequence[str], device=device, input_names=[self.input_name], output_names=output_names, - deploy_cfg=self.deploy_cfg) + deploy_cfg=self.deploy_cfg, + **kwargs) def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[Sequence[dict]], *args, **kwargs) -> list: diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index fde76460be..a133167e51 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -101,7 +101,11 @@ def init_backend_model(self, """ from .pose_detection_model import build_pose_detection_model model = build_pose_detection_model( - model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model_files, + self.model_cfg, + self.deploy_cfg, + device=self.device, + **kwargs) return model.eval() def init_pytorch_model(self, diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py index ddfb462d1b..ffc6f5a1c4 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py @@ -47,13 +47,16 @@ def __init__(self, self.deploy_cfg = deploy_cfg self.model_cfg = model_cfg self._init_wrapper( - backend=backend, backend_files=backend_files, device=device) + backend=backend, + backend_files=backend_files, + device=device, + **kwargs) # create base_head for decoding heatmap base_head = builder.build_head(model_cfg.model.keypoint_head) base_head.test_cfg = model_cfg.model.test_cfg self.base_head = base_head - def _init_wrapper(self, backend, backend_files, device): + def _init_wrapper(self, backend, backend_files, device, **kwargs): """Initialize backend wrapper. 
Args: @@ -69,7 +72,8 @@ def _init_wrapper(self, backend, backend_files, device): device=device, input_names=[self.input_name], output_names=output_names, - deploy_cfg=self.deploy_cfg) + deploy_cfg=self.deploy_cfg, + **kwargs) def forward(self, img: torch.Tensor, img_metas: Sequence[Sequence[dict]], *args, **kwargs): @@ -254,6 +258,7 @@ def build_pose_detection_model(model_files: Sequence[str], backend_files=model_files, device=device, model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + deploy_cfg=deploy_cfg, + **kwargs) return backend_pose_model diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index e23918a9d4..d3e57147b8 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -70,7 +70,11 @@ def init_backend_model(self, """ from .segmentation_model import build_segmentation_model model = build_segmentation_model( - model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model_files, + self.model_cfg, + self.deploy_cfg, + device=self.device, + **kwargs) return model.eval() def init_pytorch_model(self, diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py index a57cb9a70b..8afe220cb7 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py @@ -37,23 +37,25 @@ class End2EndModel(BaseBackendModel): object. """ - def __init__( - self, - backend: Backend, - backend_files: Sequence[str], - device: str, - class_names: Sequence[str], - palette: np.ndarray, - deploy_cfg: Union[str, mmcv.Config] = None, - ): + def __init__(self, + backend: Backend, + backend_files: Sequence[str], + device: str, + class_names: Sequence[str], + palette: np.ndarray, + deploy_cfg: Union[str, mmcv.Config] = None, + **kwargs): super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg) self.CLASSES = class_names self.PALETTE = palette self.deploy_cfg = deploy_cfg self._init_wrapper( - backend=backend, backend_files=backend_files, device=device) + backend=backend, + backend_files=backend_files, + device=device, + **kwargs) - def _init_wrapper(self, backend, backend_files, device): + def _init_wrapper(self, backend, backend_files, device, **kwargs): output_names = self.output_names self.wrapper = BaseBackendModel._build_wrapper( backend=backend, @@ -61,7 +63,8 @@ def _init_wrapper(self, backend, backend_files, device): device=device, input_names=[self.input_name], output_names=output_names, - deploy_cfg=self.deploy_cfg) + deploy_cfg=self.deploy_cfg, + **kwargs) def forward(self, img: Sequence[torch.Tensor], img_metas: Sequence[Sequence[dict]], *args, **kwargs): diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index 370da96e43..56ba0859cb 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -55,6 +55,7 @@ class Backend(AdvancedEnum): ONNXRUNTIME = 'onnxruntime' PPLNN = 'pplnn' NCNN = 'ncnn' + SNPE = 'snpe' OPENVINO = 'openvino' SDK = 'sdk' TORCHSCRIPT = 'torchscript' diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 341865f468..706ce39a86 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,3 +1,4 @@ +grpcio h5py matplotlib multiprocess diff --git a/service/snpe/client/inference_pb2.py b/service/snpe/client/inference_pb2.py new file mode 100644 index 0000000000..a0072f039e --- /dev/null +++ b/service/snpe/client/inference_pb2.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: inference.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0finference.proto\x12\x08mmdeploy\"\x91\x01\n\x05Model\x12\x11\n\x04name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x0f\n\x07weights\x18\x02 \x01(\x0c\x12+\n\x06\x64\x65vice\x18\x03 \x01(\x0e\x32\x16.mmdeploy.Model.DeviceH\x01\x88\x01\x01\"#\n\x06\x44\x65vice\x12\x07\n\x03\x43PU\x10\x00\x12\x07\n\x03GPU\x10\x01\x12\x07\n\x03\x44SP\x10\x02\x42\x07\n\x05_nameB\t\n\x07_device\"\x07\n\x05\x45mpty\"Q\n\x06Tensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\x05\x64type\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\r\n\x05shape\x18\x04 \x03(\x05\x42\x08\n\x06_dtype\",\n\nTensorList\x12\x1e\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x10.mmdeploy.Tensor\"E\n\x05Reply\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0c\n\x04info\x18\x02 \x01(\t\x12\x1e\n\x04\x64\x61ta\x18\x03 \x03(\x0b\x32\x10.mmdeploy.Tensor\"\x16\n\x05Names\x12\r\n\x05names\x18\x01 \x03(\t2\xfb\x01\n\tInference\x12*\n\x04\x45\x63ho\x12\x0f.mmdeploy.Empty\x1a\x0f.mmdeploy.Reply\"\x00\x12*\n\x04Init\x12\x0f.mmdeploy.Model\x1a\x0f.mmdeploy.Reply\"\x00\x12\x31\n\x0bOutputNames\x12\x0f.mmdeploy.Empty\x1a\x0f.mmdeploy.Names\"\x00\x12\x34\n\tInference\x12\x14.mmdeploy.TensorList\x1a\x0f.mmdeploy.Reply\"\x00\x12-\n\x07\x44\x65stroy\x12\x0f.mmdeploy.Empty\x1a\x0f.mmdeploy.Reply\"\x00\x42%\n\rmmdeploy.snpeB\x0bSNPEWrapperP\x01\xa2\x02\x04SNPEb\x06proto3' +) + +_MODEL = DESCRIPTOR.message_types_by_name['Model'] +_EMPTY = DESCRIPTOR.message_types_by_name['Empty'] +_TENSOR = DESCRIPTOR.message_types_by_name['Tensor'] +_TENSORLIST = DESCRIPTOR.message_types_by_name['TensorList'] +_REPLY = DESCRIPTOR.message_types_by_name['Reply'] +_NAMES = DESCRIPTOR.message_types_by_name['Names'] +_MODEL_DEVICE = _MODEL.enum_types_by_name['Device'] +Model = _reflection.GeneratedProtocolMessageType( + 'Model', + (_message.Message, ), + { + 'DESCRIPTOR': _MODEL, + '__module__': 'inference_pb2' + # @@protoc_insertion_point(class_scope:mmdeploy.Model) + }) +_sym_db.RegisterMessage(Model) + +Empty = _reflection.GeneratedProtocolMessageType( + 'Empty', + (_message.Message, ), + { + 'DESCRIPTOR': _EMPTY, + '__module__': 'inference_pb2' + # @@protoc_insertion_point(class_scope:mmdeploy.Empty) + }) +_sym_db.RegisterMessage(Empty) + +Tensor = _reflection.GeneratedProtocolMessageType( + 'Tensor', + (_message.Message, ), + { + 'DESCRIPTOR': _TENSOR, + '__module__': 'inference_pb2' + # @@protoc_insertion_point(class_scope:mmdeploy.Tensor) + }) +_sym_db.RegisterMessage(Tensor) + +TensorList = _reflection.GeneratedProtocolMessageType( + 'TensorList', + (_message.Message, ), + { + 'DESCRIPTOR': _TENSORLIST, + '__module__': 'inference_pb2' + # @@protoc_insertion_point(class_scope:mmdeploy.TensorList) + }) +_sym_db.RegisterMessage(TensorList) + +Reply = _reflection.GeneratedProtocolMessageType( + 'Reply', + (_message.Message, ), + { + 'DESCRIPTOR': _REPLY, + '__module__': 'inference_pb2' + # @@protoc_insertion_point(class_scope:mmdeploy.Reply) + }) +_sym_db.RegisterMessage(Reply) + +Names = 
_reflection.GeneratedProtocolMessageType( + 'Names', + (_message.Message, ), + { + 'DESCRIPTOR': _NAMES, + '__module__': 'inference_pb2' + # @@protoc_insertion_point(class_scope:mmdeploy.Names) + }) +_sym_db.RegisterMessage(Names) + +_INFERENCE = DESCRIPTOR.services_by_name['Inference'] +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\rmmdeploy.snpeB\013SNPEWrapperP\001\242\002\004SNPE' + _MODEL._serialized_start = 30 + _MODEL._serialized_end = 175 + _MODEL_DEVICE._serialized_start = 120 + _MODEL_DEVICE._serialized_end = 155 + _EMPTY._serialized_start = 177 + _EMPTY._serialized_end = 184 + _TENSOR._serialized_start = 186 + _TENSOR._serialized_end = 267 + _TENSORLIST._serialized_start = 269 + _TENSORLIST._serialized_end = 313 + _REPLY._serialized_start = 315 + _REPLY._serialized_end = 384 + _NAMES._serialized_start = 386 + _NAMES._serialized_end = 408 + _INFERENCE._serialized_start = 411 + _INFERENCE._serialized_end = 662 +# @@protoc_insertion_point(module_scope) diff --git a/service/snpe/client/inference_pb2_grpc.py b/service/snpe/client/inference_pb2_grpc.py new file mode 100644 index 0000000000..a236900f86 --- /dev/null +++ b/service/snpe/client/inference_pb2_grpc.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import inference_pb2 as inference__pb2 + + +class InferenceStub(object): + """The inference service definition.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Echo = channel.unary_unary( + '/mmdeploy.Inference/Echo', + request_serializer=inference__pb2.Empty.SerializeToString, + response_deserializer=inference__pb2.Reply.FromString, + ) + self.Init = channel.unary_unary( + '/mmdeploy.Inference/Init', + request_serializer=inference__pb2.Model.SerializeToString, + response_deserializer=inference__pb2.Reply.FromString, + ) + self.OutputNames = channel.unary_unary( + '/mmdeploy.Inference/OutputNames', + request_serializer=inference__pb2.Empty.SerializeToString, + response_deserializer=inference__pb2.Names.FromString, + ) + self.Inference = channel.unary_unary( + '/mmdeploy.Inference/Inference', + request_serializer=inference__pb2.TensorList.SerializeToString, + response_deserializer=inference__pb2.Reply.FromString, + ) + self.Destroy = channel.unary_unary( + '/mmdeploy.Inference/Destroy', + request_serializer=inference__pb2.Empty.SerializeToString, + response_deserializer=inference__pb2.Reply.FromString, + ) + + +class InferenceServicer(object): + """The inference service definition.""" + + def Echo(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Init(self, request, context): + """Init Model with model file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def OutputNames(self, request, context): + """Get output names.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Inference(self, request, context): + """Inference with inputs.""" + 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Destroy(self, request, context): + """Destroy handle.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_InferenceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Echo': + grpc.unary_unary_rpc_method_handler( + servicer.Echo, + request_deserializer=inference__pb2.Empty.FromString, + response_serializer=inference__pb2.Reply.SerializeToString, + ), + 'Init': + grpc.unary_unary_rpc_method_handler( + servicer.Init, + request_deserializer=inference__pb2.Model.FromString, + response_serializer=inference__pb2.Reply.SerializeToString, + ), + 'OutputNames': + grpc.unary_unary_rpc_method_handler( + servicer.OutputNames, + request_deserializer=inference__pb2.Empty.FromString, + response_serializer=inference__pb2.Names.SerializeToString, + ), + 'Inference': + grpc.unary_unary_rpc_method_handler( + servicer.Inference, + request_deserializer=inference__pb2.TensorList.FromString, + response_serializer=inference__pb2.Reply.SerializeToString, + ), + 'Destroy': + grpc.unary_unary_rpc_method_handler( + servicer.Destroy, + request_deserializer=inference__pb2.Empty.FromString, + response_serializer=inference__pb2.Reply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'mmdeploy.Inference', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler, )) + + +# This class is part of an EXPERIMENTAL API. +class Inference(object): + """The inference service definition.""" + + @staticmethod + def Echo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, '/mmdeploy.Inference/Echo', + inference__pb2.Empty.SerializeToString, + inference__pb2.Reply.FromString, options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, + metadata) + + @staticmethod + def Init(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, '/mmdeploy.Inference/Init', + inference__pb2.Model.SerializeToString, + inference__pb2.Reply.FromString, options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, + metadata) + + @staticmethod + def OutputNames(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, '/mmdeploy.Inference/OutputNames', + inference__pb2.Empty.SerializeToString, + inference__pb2.Names.FromString, options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, + metadata) + + @staticmethod + def Inference(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, '/mmdeploy.Inference/Inference', + inference__pb2.TensorList.SerializeToString, + 
inference__pb2.Reply.FromString, options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, + metadata) + + @staticmethod + def Destroy(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, '/mmdeploy.Inference/Destroy', + inference__pb2.Empty.SerializeToString, + inference__pb2.Reply.FromString, options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, + metadata) diff --git a/service/snpe/inference.proto b/service/snpe/inference.proto new file mode 100644 index 0000000000..3505e74023 --- /dev/null +++ b/service/snpe/inference.proto @@ -0,0 +1,70 @@ +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "mmdeploy.snpe"; +option java_outer_classname = "SNPEWrapper"; +option objc_class_prefix = "SNPE"; + +package mmdeploy; + +// The inference service definition. +service Inference { + + rpc Echo(Empty) returns (Reply) {} + + // Init Model with model file + rpc Init(Model) returns (Reply) {} + + // Get output names + rpc OutputNames(Empty) returns (Names) {} + + // Inference with inputs + rpc Inference(TensorList) returns (Reply) {} + + // Destroy handle + rpc Destroy(Empty) returns (Reply) {} +} + +message Model { + optional string name = 1; + // bin + bytes weights = 2; + // config + enum Device { + CPU = 0; + GPU = 1; + DSP = 2; + } + optional Device device = 3; +} + +// https://stackoverflow.com/questions/31768665/can-i-define-a-grpc-call-with-a-null-request-or-response +message Empty {} + +message Tensor { + // name + string name = 1; + + // datatype + optional string dtype = 2; + + // data + bytes data = 3; + + // shape + repeated int32 shape = 4; +} + +message TensorList { + repeated Tensor data = 1; +} + +message Reply { + int32 status = 1; + string info = 2; + repeated Tensor data = 3; +} + +message Names { + repeated string names = 1; +} diff --git a/service/snpe/server/CMakeLists.txt b/service/snpe/server/CMakeLists.txt new file mode 100644 index 0000000000..f14ddc97a8 --- /dev/null +++ b/service/snpe/server/CMakeLists.txt @@ -0,0 +1,81 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cmake build file for C++ helloworld example. +# Assumes protobuf and gRPC have been installed using cmake. +# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build +# that automatically builds all the dependencies before building helloworld. 
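For completeness, a hedged sketch of exercising the `inference.proto` service above directly with the generated Python stubs (run from `service/snpe/client/`; the address and model path are placeholders):

```python
import grpc

import inference_pb2
import inference_pb2_grpc

channel = grpc.insecure_channel('192.168.1.10:60000')  # placeholder device address
stub = inference_pb2_grpc.InferenceStub(channel)

print(stub.Echo(inference_pb2.Empty()).info)  # expects 'echo'

with open('end2end.dlc', 'rb') as f:          # placeholder .dlc produced by onnx2dlc
    model = inference_pb2.Model(name='end2end.dlc', weights=f.read(), device=1)  # 1 == GPU
reply = stub.Init(model)
assert reply.status == 0, reply.info

print(list(stub.OutputNames(inference_pb2.Empty()).names))
stub.Destroy(inference_pb2.Empty())
```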
+ +cmake_minimum_required(VERSION 3.5.1) +project(SNPEServer C CXX) +include(./common.cmake) + +# Proto file +get_filename_component(hw_proto "../inference.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) + +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/inference.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/inference.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/inference.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/inference.grpc.pb.h") + +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") + +# Include generated *.pb.h files +include_directories("${CMAKE_CURRENT_BINARY_DIR}") + +# hw_grpc_proto +add_library(hw_grpc_proto + ${hw_grpc_srcs} + ${hw_grpc_hdrs} + ${hw_proto_srcs} + ${hw_proto_hdrs}) + +target_link_libraries(hw_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF}) + +add_library(snpe SHARED IMPORTED) + +if (NOT EXISTS $ENV{SNPE_ROOT}/lib/aarch64-android-clang6.0/) + message(FATAL_ERROR "SNPE_ROOT directory not exist: "$ENV{SNPE_ROOT}/lib/aarch64-android-clang6.0/) +endif() + +set_target_properties(snpe PROPERTIES + IMPORTED_LOCATION "$ENV{SNPE_ROOT}/lib/aarch64-android-clang6.0/libSNPE.so" + INTERFACE_INCLUDE_DIRECTORIES "$ENV{SNPE_ROOT}/include/zdl" +) +target_link_directories( + snpe + INTERFACE +) + +add_executable(inference_server inference_server.cc service_impl.cpp) + +target_link_libraries(inference_server + hw_grpc_proto + ${_REFLECTION} + ${_GRPC_GRPCPP} + ${_PROTOBUF_LIBPROTOBUF} + snpe) diff --git a/service/snpe/server/common.cmake b/service/snpe/server/common.cmake new file mode 100644 index 0000000000..20d2f0c01e --- /dev/null +++ b/service/snpe/server/common.cmake @@ -0,0 +1,123 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cmake build file for C++ route_guide example. +# Assumes protobuf and gRPC have been installed using cmake. +# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build +# that automatically builds all the dependencies before building route_guide. + +cmake_minimum_required(VERSION 3.5.1) + +set (CMAKE_CXX_STANDARD 17) + +if(MSVC) + add_definitions(-D_WIN32_WINNT=0x600) +endif() + +find_package(Threads REQUIRED) + +if(GRPC_AS_SUBMODULE) + # One way to build a projects that uses gRPC is to just include the + # entire gRPC project tree via "add_subdirectory". + # This approach is very simple to use, but the are some potential + # disadvantages: + # * it includes gRPC's CMakeLists.txt directly into your build script + # without and that can make gRPC's internal setting interfere with your + # own build. 
+ # * depending on what's installed on your system, the contents of submodules + # in gRPC's third_party/* might need to be available (and there might be + # additional prerequisites required to build them). Consider using + # the gRPC_*_PROVIDER options to fine-tune the expected behavior. + # + # A more robust approach to add dependency on gRPC is using + # cmake's ExternalProject_Add (see cmake_externalproject/CMakeLists.txt). + + # Include the gRPC's cmake build (normally grpc source code would live + # in a git submodule called "third_party/grpc", but this example lives in + # the same repository as gRPC sources, so we just look a few directories up) + add_subdirectory(../../.. ${CMAKE_CURRENT_BINARY_DIR}/grpc EXCLUDE_FROM_ALL) + message(STATUS "Using gRPC via add_subdirectory.") + + # After using add_subdirectory, we can now use the grpc targets directly from + # this build. + set(_PROTOBUF_LIBPROTOBUF libprotobuf) + set(_REFLECTION grpc++_reflection) + if(CMAKE_CROSSCOMPILING) + find_program(_PROTOBUF_PROTOC protoc) + else() + set(_PROTOBUF_PROTOC $) + endif() + set(_GRPC_GRPCPP grpc++) + if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + endif() +elseif(GRPC_FETCHCONTENT) + # Another way is to use CMake's FetchContent module to clone gRPC at + # configure time. This makes gRPC's source code available to your project, + # similar to a git submodule. + message(STATUS "Using gRPC via add_subdirectory (FetchContent).") + include(FetchContent) + FetchContent_Declare( + grpc + GIT_REPOSITORY https://github.com/grpc/grpc.git + # when using gRPC, you will actually set this to an existing tag, such as + # v1.25.0, v1.26.0 etc.. + # For the purpose of testing, we override the tag used to the commit + # that's currently under test. + GIT_TAG vGRPC_TAG_VERSION_OF_YOUR_CHOICE) + FetchContent_MakeAvailable(grpc) + + # Since FetchContent uses add_subdirectory under the hood, we can use + # the grpc targets directly from this build. + set(_PROTOBUF_LIBPROTOBUF libprotobuf) + set(_REFLECTION grpc++_reflection) + set(_PROTOBUF_PROTOC $) + set(_GRPC_GRPCPP grpc++) + if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + endif() +else() + # This branch assumes that gRPC and all its dependencies are already installed + # on this system, so they can be located by find_package(). + + # Find Protobuf installation + # Looks for protobuf-config.cmake file installed by Protobuf's cmake installation. + set(protobuf_MODULE_COMPATIBLE TRUE) + find_package(Protobuf CONFIG REQUIRED) + message(STATUS "Using protobuf ${Protobuf_VERSION}") + + set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) + set(_REFLECTION gRPC::grpc++_reflection) + if(CMAKE_CROSSCOMPILING) + find_program(_PROTOBUF_PROTOC protoc) + else() + set(_PROTOBUF_PROTOC $) + endif() + + # Find gRPC installation + # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. 
+ find_package(gRPC CONFIG REQUIRED) + message(STATUS "Using gRPC ${gRPC_VERSION}") + + set(_GRPC_GRPCPP gRPC::grpc++) + if(CMAKE_CROSSCOMPILING) + find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + else() + set(_GRPC_CPP_PLUGIN_EXECUTABLE $) + endif() +endif() diff --git a/service/snpe/server/inference_server.cc b/service/snpe/server/inference_server.cc new file mode 100644 index 0000000000..9369cce48a --- /dev/null +++ b/service/snpe/server/inference_server.cc @@ -0,0 +1,109 @@ +/* + * + * Copyright 2015 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include + +#include "service_impl.h" +#include "text_table.h" + +void PrintIP() { + struct ifaddrs* ifAddrStruct = NULL; + void* tmpAddrPtr = NULL; + + int retval = getifaddrs(&ifAddrStruct); + if (retval == -1) { + return; + } + + helper::TextTable table("Device"); + table.padding(1); + table.add("port").add("ip").eor(); + while (ifAddrStruct != nullptr) { + if (ifAddrStruct->ifa_addr == nullptr) { + break; + } + + if (ifAddrStruct->ifa_addr->sa_family == AF_INET) { + tmpAddrPtr = &((struct sockaddr_in*)ifAddrStruct->ifa_addr)->sin_addr; + char addressBuffer[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN); + table.add(std::string(ifAddrStruct->ifa_name)).add(std::string(addressBuffer)).eor(); + } else if (ifAddrStruct->ifa_addr->sa_family == AF_INET6) { + tmpAddrPtr = &((struct sockaddr_in*)ifAddrStruct->ifa_addr)->sin_addr; + char addressBuffer[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, tmpAddrPtr, addressBuffer, INET6_ADDRSTRLEN); + table.add(std::string(ifAddrStruct->ifa_name)).add(std::string(addressBuffer)).eor(); + } + ifAddrStruct = ifAddrStruct->ifa_next; + } + std::cout << table << std::endl << std::endl; +} + +void RunServer(int port = 60000) { + // listen IPv4 and IPv6 + char server_address[64] = {0}; + sprintf(server_address, "[::]:%d", port); + InferenceServiceImpl service; + + grpc::EnableDefaultHealthCheckService(true); + grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + + // Max 128MB + builder.SetMaxMessageSize(2 << 29); + builder.SetMaxSendMessageSize(2 << 29); + + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + fprintf(stdout, "Server listening on %s\n", server_address); + + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. 
+ server->Wait(); +} + +int main(int argc, char** argv) { + int port = 60000; + if (argc > 1) { + port = std::stoi(argv[1]); + } + + if (port <= 9999) { + fprintf(stdout, "Usage: %s [port]\n", argv[0]); + return 0; + } + PrintIP(); + RunServer(port); + + return 0; +} diff --git a/service/snpe/server/scope_timer.h b/service/snpe/server/scope_timer.h new file mode 100644 index 0000000000..3730061257 --- /dev/null +++ b/service/snpe/server/scope_timer.h @@ -0,0 +1,34 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#pragma once + +#include + +#include +#include +#include + +class ScopeTimer { + public: + ScopeTimer(std::string _name, bool _print = false) : name(_name), print(_print) { begin = now(); } + + ~ScopeTimer() { + if (!print) { + return; + } + fprintf(stdout, "%s: %ldms\n", name.c_str(), (now() - begin)); + } + + long now() const { + struct timeval tv; + gettimeofday(&tv, NULL); + return tv.tv_sec * 1000 + (tv.tv_usec / 1000); + } + + long cost() const { return now() - begin; } + + private: + std::string name; + bool print; + long begin; +}; diff --git a/service/snpe/server/service_impl.cpp b/service/snpe/server/service_impl.cpp new file mode 100644 index 0000000000..6db484bd3e --- /dev/null +++ b/service/snpe/server/service_impl.cpp @@ -0,0 +1,358 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "service_impl.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "scope_timer.h" +#include "text_table.h" + +zdl::DlSystem::Runtime_t InferenceServiceImpl::CheckRuntime(zdl::DlSystem::Runtime_t runtime, + bool& staticQuantization) { + static zdl::DlSystem::Version_t Version = zdl::SNPE::SNPEFactory::getLibraryVersion(); + + fprintf(stdout, "SNPE Version: %s\n", Version.asString().c_str()); + + if ((runtime != zdl::DlSystem::Runtime_t::DSP) && staticQuantization) { + fprintf(stderr, + "ERROR: Cannot use static quantization with CPU/GPU runtimes. " + "It is only designed for DSP/AIP runtimes.\n" + "ERROR: Proceeding without static quantization on selected " + "runtime.\n"); + staticQuantization = false; + } + + if (!zdl::SNPE::SNPEFactory::isRuntimeAvailable(runtime)) { + fprintf(stderr, "Selected runtime not present. 
Falling back to CPU.\n"); + runtime = zdl::DlSystem::Runtime_t::CPU; + } + + return runtime; +} + +void InferenceServiceImpl::Build(std::unique_ptr& container, + zdl::DlSystem::Runtime_t runtime, + zdl::DlSystem::RuntimeList runtimeList, + bool useUserSuppliedBuffers, + zdl::DlSystem::PlatformConfig platformConfig) { + zdl::SNPE::SNPEBuilder snpeBuilder(container.get()); + + if (runtimeList.empty()) { + runtimeList.add(runtime); + } + + snpe = snpeBuilder.setOutputLayers({}) + .setRuntimeProcessorOrder(runtimeList) + .setUseUserSuppliedBuffers(useUserSuppliedBuffers) + .setPlatformConfig(platformConfig) + .setExecutionPriorityHint(zdl::DlSystem::ExecutionPriorityHint_t::HIGH) + .setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::SUSTAINED_HIGH_PERFORMANCE) + .build(); + return; +} + +void InferenceServiceImpl::SaveDLC(const ::mmdeploy::Model* request, const std::string& filename) { + auto model = request->weights(); + fprintf(stdout, "saving file to %s\n", filename.c_str()); + std::ofstream fout; + fout.open(filename, std::ios::binary | std::ios::out); + fout.write(model.data(), model.size()); + fout.flush(); + fout.close(); +} + +void InferenceServiceImpl::LoadFloatData(const std::string& data, std::vector& vec) { + size_t len = data.size(); + assert(len % sizeof(float) == 0); + const char* ptr = data.data(); + for (int i = 0; i < len; i += sizeof(float)) { + vec.push_back(*(float*)(ptr + i)); + } +} + +::grpc::Status InferenceServiceImpl::Echo(::grpc::ServerContext* context, + const ::mmdeploy::Empty* request, + ::mmdeploy::Reply* response) { + response->set_info("echo"); + return Status::OK; +} + +// Logic and data behind the server's behavior. +::grpc::Status InferenceServiceImpl::Init(::grpc::ServerContext* context, + const ::mmdeploy::Model* request, + ::mmdeploy::Reply* response) { + zdl::SNPE::SNPEFactory::initializeLogging(zdl::DlSystem::LogLevel_t::LOG_ERROR); + zdl::SNPE::SNPEFactory::setLogLevel(zdl::DlSystem::LogLevel_t::LOG_ERROR); + + if (snpe != nullptr) { + snpe.reset(); + } + if (container != nullptr) { + container.reset(); + } + + auto model = request->weights(); + container = + zdl::DlContainer::IDlContainer::open(reinterpret_cast(model.data()), model.size()); + if (container == nullptr) { + fprintf(stdout, "Stage Init: load dlc failed.\n"); + + response->set_status(-1); + response->set_info(zdl::DlSystem::getLastErrorString()); + return Status::OK; + } + fprintf(stdout, "Stage Init: load dlc success.\n"); + + zdl::DlSystem::Runtime_t runtime = zdl::DlSystem::Runtime_t::GPU; + if (request->has_device()) { + switch (request->device()) { + case mmdeploy::Model_Device_GPU: + runtime = zdl::DlSystem::Runtime_t::GPU; + break; + case mmdeploy::Model_Device_DSP: + runtime = zdl::DlSystem::Runtime_t::DSP; + default: + break; + } + } + + if (runtime != zdl::DlSystem::Runtime_t::CPU) { + bool static_quant = false; + runtime = CheckRuntime(runtime, static_quant); + } + + zdl::DlSystem::RuntimeList runtimeList; + runtimeList.add(zdl::DlSystem::Runtime_t::CPU); + runtimeList.add(runtime); + zdl::DlSystem::PlatformConfig platformConfig; + + { + ScopeTimer timer("build snpe"); + Build(container, runtime, runtimeList, false, platformConfig); + } + + if (snpe == nullptr) { + response->set_status(-1); + response->set_info(zdl::DlSystem::getLastErrorString()); + } + + // setup logger + auto logger_opt = snpe->getDiagLogInterface(); + if (!logger_opt) throw std::runtime_error("SNPE failed to obtain logging interface"); + auto logger = *logger_opt; + auto opts = 
logger->getOptions(); + static std::string OutputDir = "./output/"; + + opts.LogFileDirectory = OutputDir; + if (!logger->setOptions(opts)) { + std::cerr << "Failed to set options" << std::endl; + return Status::OK; + } + if (!logger->start()) { + std::cerr << "Failed to start logger" << std::endl; + return Status::OK; + } + + const auto& inputTensorNamesRef = snpe->getInputTensorNames(); + const auto& inputTensorNames = *inputTensorNamesRef; + + inputTensors.resize(inputTensorNames.size()); + for (int i = 0; i < inputTensorNames.size(); ++i) { + const char* pname = inputTensorNames.at(i); + const auto& shape_opt = snpe->getInputDimensions(pname); + const auto& shape = *shape_opt; + + fprintf(stdout, "Stage Init: input tensor info:\n"); + switch (shape.rank()) { + case 1: + fprintf(stdout, "name: %s, shape: [%ld]\n", pname, shape[0]); + break; + case 2: + fprintf(stdout, "name: %s, shape: [%ld,%ld]\n", pname, shape[0], shape[1]); + break; + case 3: + fprintf(stdout, "name: %s, shape: [%ld,%ld,%ld]\n", pname, shape[0], shape[1], shape[2]); + break; + case 4: + fprintf(stdout, "name: %s, shape: [%ld,%ld,%ld,%ld]\n", pname, shape[0], shape[1], shape[2], + shape[3]); + break; + } + inputTensors[i] = zdl::SNPE::SNPEFactory::getTensorFactory().createTensor(shape); + inputTensorMap.add(pname, inputTensors[i].get()); + } + + response->set_status(0); + response->set_info("Stage Init: success"); + return Status::OK; +} + +std::string InferenceServiceImpl::ContentStr(zdl::DlSystem::ITensor* pTensor) { + std::string str; + + const size_t N = std::min(5UL, pTensor->getSize()); + auto it = pTensor->cbegin(); + for (int i = 0; i < N; ++i) { + str += std::to_string(*(it + i)); + str += " "; + } + str += ".."; + str += std::to_string(*(it + pTensor->getSize() - 1)); + return str; +} + +std::string InferenceServiceImpl::ShapeStr(zdl::DlSystem::ITensor* pTensor) { + std::string str; + + str += "["; + auto shape = pTensor->getShape(); + for (int i = 0; i < shape.rank(); ++i) { + str += std::to_string(shape[i]); + str += ","; + } + str += ']'; + return str; +} + +::grpc::Status InferenceServiceImpl::OutputNames(::grpc::ServerContext* context, + const ::mmdeploy::Empty* request, + ::mmdeploy::Names* response) { + const auto& outputTensorNamesRef = snpe->getOutputTensorNames(); + const auto& outputTensorNames = *outputTensorNamesRef; + + for (int i = 0; i < outputTensorNames.size(); ++i) { + response->add_names(outputTensorNames.at(i)); + } + + return Status::OK; +} + +::grpc::Status InferenceServiceImpl::Inference(::grpc::ServerContext* context, + const ::mmdeploy::TensorList* request, + ::mmdeploy::Reply* response) { + // Get input names and number + const auto& inputTensorNamesRef = snpe->getInputTensorNames(); + + if (!inputTensorNamesRef) { + response->set_status(-1); + response->set_info(zdl::DlSystem::getLastErrorString()); + return Status::OK; + } + + const auto& inputTensorNames = *inputTensorNamesRef; + if (inputTensorNames.size() != request->data_size()) { + response->set_status(-1); + response->set_info("Stage Inference: input names count not match !"); + return Status::OK; + } + + helper::TextTable table("Inference"); + table.padding(1); + table.add("type").add("name").add("shape").add("content").eor(); + + // Load input/output buffers with TensorMap + { + // ScopeTimer timer("convert input"); + + for (int i = 0; i < request->data_size(); ++i) { + auto tensor = request->data(i); + std::vector float_input; + LoadFloatData(tensor.data(), float_input); + + zdl::DlSystem::ITensor* ptensor = 
inputTensorMap.getTensor(tensor.name().c_str()); + if (ptensor == nullptr) { + fprintf(stderr, "Stage Inference: name: %s not existed in input tensor map\n", + tensor.name().c_str()); + response->set_status(-1); + response->set_info("cannot find name in input tensor map."); + return Status::OK; + } + + if (float_input.size() != ptensor->getSize()) { + fprintf(stderr, "Stage Inference: input size not match, get %ld, expect %ld.\n", + float_input.size(), ptensor->getSize()); + response->set_status(-1); + response->set_info(zdl::DlSystem::getLastErrorString()); + return Status::OK; + } + + std::copy(float_input.begin(), float_input.end(), ptensor->begin()); + + table.add("IN").add(tensor.name()).add(ShapeStr(ptensor)).add(ContentStr(ptensor)).eor(); + } + } + + // A tensor map for SNPE execution outputs + zdl::DlSystem::TensorMap outputTensorMap; + // Execute the multiple input tensorMap on the model with SNPE + bool success = false; + { + ScopeTimer timer("execute", false); + success = snpe->execute(inputTensorMap, outputTensorMap); + + if (!success) { + response->set_status(-1); + response->set_info(zdl::DlSystem::getLastErrorString()); + return Status::OK; + } + + table.add("EXECUTE").add(std::to_string(timer.cost()) + "ms").eor(); + } + + { + // ScopeTimer timer("convert output"); + auto out_names = outputTensorMap.getTensorNames(); + for (size_t i = 0; i < out_names.size(); ++i) { + const char* name = out_names.at(i); + zdl::DlSystem::ITensor* ptensor = outputTensorMap.getTensor(name); + + table.add("OUT").add(std::string(name)).add(ShapeStr(ptensor)).add(ContentStr(ptensor)).eor(); + + const size_t data_length = ptensor->getSize(); + + std::string result; + result.resize(sizeof(float) * data_length); + int j = 0; + for (auto it = ptensor->cbegin(); it != ptensor->cend(); ++it, j += sizeof(float)) { + float f = *it; + memcpy(&result[0] + j, reinterpret_cast(&f), sizeof(float)); + } + + auto shape = ptensor->getShape(); + + ::mmdeploy::Tensor* pData = response->add_data(); + pData->set_dtype("float32"); + pData->set_name(name); + pData->set_data(result); + for (int j = 0; j < shape.rank(); ++j) { + pData->add_shape(shape[j]); + } + } + } + + std::cout << table << std::endl << std::endl; + + // build output status + response->set_status(0); + response->set_info("Stage Inference: success"); + return Status::OK; +} + +::grpc::Status InferenceServiceImpl::Destroy(::grpc::ServerContext* context, + const ::mmdeploy::Empty* request, + ::mmdeploy::Reply* response) { + snpe.reset(); + container.reset(); + inputTensors.clear(); + response->set_status(0); + zdl::SNPE::SNPEFactory::terminateLogging(); + return Status::OK; +} diff --git a/service/snpe/server/service_impl.h b/service/snpe/server/service_impl.h new file mode 100644 index 0000000000..c6b825fdb4 --- /dev/null +++ b/service/snpe/server/service_impl.h @@ -0,0 +1,78 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#ifndef SERVICE_IMPL_H +#define SERVICE_IMPL_H + +#include +#include +#include + +#include +#include +#include + +#include "DiagLog/IDiagLog.hpp" +#include "DlContainer/IDlContainer.hpp" +#include "DlSystem/DlEnums.hpp" +#include "DlSystem/DlError.hpp" +#include "DlSystem/ITensorFactory.hpp" +#include "DlSystem/IUserBuffer.hpp" +#include "DlSystem/PlatformConfig.hpp" +#include "DlSystem/RuntimeList.hpp" +#include "DlSystem/UserBufferMap.hpp" +#include "SNPE/SNPE.hpp" +#include "SNPE/SNPEBuilder.hpp" +#include "SNPE/SNPEFactory.hpp" +#include "inference.grpc.pb.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; + +using mmdeploy::Empty; +using mmdeploy::Inference; +using mmdeploy::Model; +using mmdeploy::Reply; +using mmdeploy::Tensor; +using mmdeploy::TensorList; + +// Logic and data behind the server's behavior. +class InferenceServiceImpl final : public Inference::Service { + ::grpc::Status Echo(::grpc::ServerContext* context, const ::mmdeploy::Empty* request, + ::mmdeploy::Reply* response) override; + + // Init Model with model file + ::grpc::Status Init(::grpc::ServerContext* context, const ::mmdeploy::Model* request, + ::mmdeploy::Reply* response) override; + // Get output names + ::grpc::Status OutputNames(::grpc::ServerContext* context, const ::mmdeploy::Empty* request, + ::mmdeploy::Names* response) override; + // Inference with inputs + ::grpc::Status Inference(::grpc::ServerContext* context, const ::mmdeploy::TensorList* request, + ::mmdeploy::Reply* response) override; + // Destroy handle + ::grpc::Status Destroy(::grpc::ServerContext* context, const ::mmdeploy::Empty* request, + ::mmdeploy::Reply* response) override; + + void SaveDLC(const ::mmdeploy::Model* request, const std::string& name); + + void LoadFloatData(const std::string& data, std::vector& vec); + + zdl::DlSystem::Runtime_t CheckRuntime(zdl::DlSystem::Runtime_t runtime, bool& staticQuantization); + + void Build(std::unique_ptr& container, + zdl::DlSystem::Runtime_t runtime, zdl::DlSystem::RuntimeList runtimeList, + bool useUserSuppliedBuffers, zdl::DlSystem::PlatformConfig platformConfig); + + std::string ShapeStr(zdl::DlSystem::ITensor* pTensor); + + std::string ContentStr(zdl::DlSystem::ITensor* pTensor); + + std::unique_ptr snpe; + std::unique_ptr container; + std::vector> inputTensors; + zdl::DlSystem::TensorMap inputTensorMap; +}; + +#endif diff --git a/service/snpe/server/text_table.h b/service/snpe/server/text_table.h new file mode 100644 index 0000000000..39ea330880 --- /dev/null +++ b/service/snpe/server/text_table.h @@ -0,0 +1,209 @@ +/** + * \file sdk/load-and-run/src/text_table.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+namespace helper {
+
+class TextTable {
+ public:
+  enum Level { Summary, Detail };
+  enum class Align : int { Left, Right, Mid };
+  TextTable() = default;
+  explicit TextTable(const std::string& table_name) : m_name(table_name) {}
+  TextTable& horizontal(char c) {
+    m_row.params.horizontal = c;
+    return *this;
+  }
+  TextTable& vertical(char c) {
+    m_row.params.vertical = c;
+    return *this;
+  }
+  TextTable& corner(char c) {
+    m_row.params.corner = c;
+    return *this;
+  }
+  TextTable& align(Align v) {
+    m_row.params.align = v;
+    return *this;
+  }
+  TextTable& padding(size_t w) {
+    m_padding = w;
+    return *this;
+  }
+  TextTable& prefix(const std::string& str) {
+    m_prefix = str;
+    return *this;
+  }
+
+  template <typename T>
+  TextTable& add(const T& value) {
+    m_row.values.emplace_back(value);
+    if (m_cols_max_w.size() < m_row.values.size()) {
+      m_cols_max_w.emplace_back(m_row.values.back().length());
+    } else {
+      size_t i = m_row.values.size() - 1;
+      m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length());
+    }
+    return *this;
+  }
+
+  template <typename T, typename std::enable_if<std::is_floating_point<T>::value, bool>::type = 0>
+  TextTable& add(const T& value) {
+    std::stringstream ss;
+    ss << std::setiosflags(std::ios::fixed) << std::setprecision(2);
+    ss << value;
+    m_row.values.emplace_back(ss.str());
+    if (m_cols_max_w.size() < m_row.values.size()) {
+      m_cols_max_w.emplace_back(m_row.values.back().length());
+    } else {
+      size_t i = m_row.values.size() - 1;
+      m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length());
+    }
+    return *this;
+  }
+
+  template <typename T, typename std::enable_if<std::is_integral<T>::value, bool>::type = 0>
+  TextTable& add(const T& value) {
+    m_row.values.emplace_back(std::to_string(value));
+    return *this;
+  }
+
+  void eor() {
+    m_rows.emplace_back(m_row);
+    adjuster_last_row();
+    m_row.values.clear();
+  }
+
+  void reset() {
+    m_row = {};
+    m_cols_max_w.clear();
+    m_padding = 0;
+    m_rows.clear();
+  }
+
+  void show(std::ostream& os) {
+    if (m_rows.empty()) return;
+    auto& last_row = m_rows.front();
+    bool first = true;
+    for (auto& row : m_rows) {
+      auto& lrow = (last_row.values.size() * char_length(last_row.params.horizontal)) >
+                           (row.values.size() * char_length(row.params.horizontal))
+                       ? last_row
+                       : row;
+      // line before row
+      if (lrow.params.horizontal) {
+        if (not first) os << std::endl;
+        os << m_prefix;
+        if (lrow.params.corner) os << lrow.params.corner;
+        size_t skip_size = 0;
+        // table name
+        if (first) {
+          os << m_name;
+          skip_size = m_name.length();
+        }
+        for (size_t i = 0; i < lrow.values.size(); ++i) {
+          auto max_w = m_cols_max_w.at(i) + m_padding * 2;
+          if (max_w + char_length(lrow.params.corner) <= skip_size) {
+            skip_size = skip_size - max_w - char_length(lrow.params.corner);
+            continue;
+          }
+          size_t rest = max_w + char_length(lrow.params.corner) - skip_size;
+          skip_size = 0;
+          if (rest > char_length(lrow.params.corner)) {
+            os << std::string(rest - char_length(lrow.params.corner), lrow.params.horizontal);
+            rest = char_length(lrow.params.corner);
+          }
+          if (rest > 0 && lrow.params.corner) os << lrow.params.corner;
+        }
+      } else if (first) {
+        os << m_prefix << ' ' << m_name;
+      }
+      first = false;
+      os << std::endl << m_prefix;
+      if (row.params.vertical) os << row.params.vertical;
+      // row
+      for (size_t i = 0; i < row.values.size(); ++i) {
+        auto& str = row.values.at(i);
+        auto max_w = m_cols_max_w.at(i) + 2 * m_padding;
+        if (row.params.align == Align::Mid) {
+          mid(os, str, max_w);
+        } else if (row.params.align == Align::Left) {
+          os << std::setw(max_w) << std::left << str;
+        } else {
+          os << std::setw(max_w) << std::right << str;
+        }
+        if (row.params.vertical) os << row.params.vertical;
+      }
+      last_row = row;
+    }
+    if (last_row.params.horizontal) {
+      os << std::endl << m_prefix;
+      if (last_row.params.corner) os << last_row.params.corner;
+      for (size_t i = 0; i < last_row.values.size(); ++i) {
+        auto max_w = m_cols_max_w.at(i);
+        std::string tmp(max_w + m_padding * 2, last_row.params.horizontal);
+        os << tmp;
+        if (last_row.params.corner) os << last_row.params.corner;
+      }
+    }
+  }
+
+ private:
+  void adjuster_last_row() {
+    if (m_rows.empty()) return;
+    auto& row = m_rows.back();
+    if (row.params.horizontal == 0 or row.params.vertical == 0) {
+      row.params.corner = 0;
+    }
+    if (row.params.horizontal != 0 && row.params.vertical != 0 && row.params.corner == 0) {
+      row.params.corner = row.params.horizontal;
+    }
+  }
+
+  inline void mid(std::ostream& os, const std::string& str, size_t max_w) {
+    size_t l = (max_w - str.length()) / 2 + str.length();
+    size_t r = max_w - l;
+    os << std::setw(l) << std::right << str;
+    if (r > 0) os << std::setw(r) << ' ';
+  }
+  inline size_t char_length(char c) { return c ? 1 : 0; }
+  std::string m_name;
+  std::vector<size_t> m_cols_max_w;
+  size_t m_padding = 0;
+  std::string m_prefix = "";
+  struct Row {
+    std::vector<std::string> values;
+    struct Params {
+      Align align = Align::Left;
+      char horizontal = '-', vertical = '|', corner = '+';
+    } params;
+  };
+  std::vector<Row> m_rows;
+  Row m_row;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, TextTable& table) {
+  table.show(stream);
+  return stream;
+}
+
+}  // namespace helper
diff --git a/setup.cfg b/setup.cfg
index 2073768695..b02db3a5c4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,3 +14,4 @@ known_first_party = mmdeploy
 known_third_party = h5py,m2r,mmcls,mmcv,mmdeploy_python,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
+skip = service/snpe/client/inference_pb2.py,service/snpe/client/inference_pb2_grpc.py
diff --git a/tools/check_env.py b/tools/check_env.py
index 20b9907c85..e25806fa57 100644
--- a/tools/check_env.py
+++ b/tools/check_env.py
@@ -41,6 +41,9 @@ def check_backend():
     import mmdeploy.apis.openvino as openvino_apis
     logger.info(f'openvino_is_avaliable: {openvino_apis.is_available()}')
+    import mmdeploy.apis.snpe as snpe_apis
+    logger.info(f'snpe_is_available: {snpe_apis.is_available()}')
+
 
 
 def check_codebase():
     codebase_versions = get_codebase_version()
diff --git a/tools/deploy.py b/tools/deploy.py
index 537c7aeec2..0452cec23f 100644
--- a/tools/deploy.py
+++ b/tools/deploy.py
@@ -54,6 +54,10 @@ def parse_args():
         help='Image directory for quantize model.')
     parser.add_argument(
         '--quant', action='store_true', help='Quantize model to low bit.')
+    parser.add_argument(
+        '--uri',
+        default='192.168.1.1:60000',
+        help='Remote ipv4:port or ipv6:port for inference on edge device.')
     args = parser.parse_args()
 
     return args
@@ -266,6 +270,30 @@ def main():
         else:
             backend_files += [model_param_path, model_bin_path]
 
+    elif backend == Backend.SNPE:
+        from mmdeploy.apis.snpe import is_available as is_available
+
+        if not is_available():
+            logger.error('snpe support is not available, please check \
+                1) `snpe-onnx-to-dlc` exists in `PATH` 2) snpe only supports \
+                ubuntu18.04')
+            exit(1)
+
+        import mmdeploy.apis.snpe as snpe_api
+        from mmdeploy.apis.snpe import get_env_key, get_output_model_file
+
+        if get_env_key() not in os.environ:
+            os.environ[get_env_key()] = args.uri
+
+        PIPELINE_MANAGER.set_log_level(log_level, [snpe_api.from_onnx])
+
+        backend_files = []
+        for onnx_path in ir_files:
+            dlc_path = get_output_model_file(onnx_path, args.work_dir)
+            onnx_name = osp.splitext(osp.split(onnx_path)[1])[0]
+            snpe_api.from_onnx(onnx_path, osp.join(args.work_dir, onnx_name))
+            backend_files = [dlc_path]
+
     elif backend == Backend.OPENVINO:
         from mmdeploy.apis.openvino import \
             is_available as is_available_openvino
@@ -331,17 +359,19 @@
     # for headless installation.
     if not headless:
-        # visualize model of the backend
+        extra = dict(
+            backend=backend,
+            output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
+            show_result=args.show)
+        if backend == Backend.SNPE:
+            extra['uri'] = args.uri
+
         create_process(
             f'visualize {backend.value} model',
             target=visualize_model,
             args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
                   args.device),
-            kwargs=dict(
-                backend=backend,
-                output_file=osp.join(args.work_dir,
-                                     f'output_{backend.value}.jpg'),
-                show_result=args.show),
+            kwargs=extra,
             ret_value=ret_value)
 
         # visualize pytorch model
diff --git a/tools/onnx2dlc.py b/tools/onnx2dlc.py
new file mode 100644
index 0000000000..526a9c7f7d
--- /dev/null
+++ b/tools/onnx2dlc.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import logging
+
+from mmdeploy.apis.snpe import from_onnx
+from mmdeploy.utils import get_root_logger
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Convert ONNX to snpe dlc format.')
+    parser.add_argument('onnx_path', help='ONNX model path')
+    parser.add_argument('output_prefix', help='output snpe dlc model path')
+    parser.add_argument(
+        '--log-level',
+        help='set log level',
+        default='INFO',
+        choices=list(logging._nameToLevel.keys()))
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+    logger = get_root_logger(log_level=args.log_level)
+
+    onnx_path = args.onnx_path
+    output_prefix = args.output_prefix
+
+    logger.info(f'onnx2dlc: \n\tonnx_path: {onnx_path} ')
+    from_onnx(onnx_path, output_prefix)
+    logger.info('onnx2dlc success.')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/test.py b/tools/test.py
index 50ae79ca44..04ffcffca7 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -73,6 +73,11 @@ def parse_args():
         help='the interval between each log, require setting '
         'speed-test first',
         default=100)
+    parser.add_argument(
+        '--uri',
+        type=str,
+        default='192.168.1.1:60000',
+        help='Remote ipv4:port or ipv6:port for inference on edge device.')
     args = parser.parse_args()
 
     return args
@@ -103,7 +108,7 @@ def main():
         workers_per_gpu=model_cfg.data.workers_per_gpu)
 
     # load the model of the backend
-    model = task_processor.init_backend_model(args.model)
+    model = task_processor.init_backend_model(args.model, uri=args.uri)
 
     is_device_cpu = (args.device == 'cpu')
     device_id = None if is_device_cpu else parse_device_id(args.device)
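
Note on usage: the conversion entry points that tools/deploy.py and tools/onnx2dlc.py call above can also be driven directly from Python. The sketch below is illustrative only, not part of the patch; it assumes mmdeploy is installed with the SNPE backend enabled and `snpe-onnx-to-dlc` available in PATH, and the file paths are placeholders.

    # Minimal sketch: convert an ONNX file to an SNPE .dlc, mirroring tools/onnx2dlc.py.
    import os.path as osp

    from mmdeploy.apis.snpe import from_onnx, get_output_model_file, is_available

    assert is_available(), 'snpe backend is not available'

    onnx_path = 'work_dir/end2end.onnx'  # placeholder path
    work_dir = 'work_dir'
    # from_onnx takes the ONNX path and an output prefix (no extension), as in deploy.py.
    output_prefix = osp.join(work_dir, osp.splitext(osp.basename(onnx_path))[0])
    from_onnx(onnx_path, output_prefix)
    # get_output_model_file resolves the resulting .dlc path inside work_dir.
    dlc_path = get_output_model_file(onnx_path, work_dir)
    print('generated:', dlc_path)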