forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Python wrapper for SDK (open-mmlab#27)
* add python API for detector * integrate detection * add python segmentor * add segmentation support * add classifier, text-detector, text-recognizer and restorer * integrate classifier * integrate textdet, textrecog and restorer * simplify * add inst-seg * fix inst-seg * integrate inst-seg * Moidfy _build_wrapper * better pipeline substitution * use registry for backend model creation * build Python module according to C API targets * minor fix * move sdk data pipeline to backend_config * remove debugging lines * add docstring for SDKEnd2EndModel * fix type hint * fix lint * fix lint * insert build/lib to sys.path Co-authored-by: SingleZombie <[email protected]>
- Loading branch information
1 parent
12ee956
commit bb655af
Showing
35 changed files
with
962 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
[settings] | ||
known_third_party = h5py,m2r,mmcls,mmcv,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision | ||
known_third_party = h5py,m2r,mmcls,mmcv,mmdeploy_python,mmdet,mmedit,mmocr,mmseg,ncnn,numpy,onnx,onnxruntime,packaging,pyppeteer,pyppl,pytest,pytorch_sphinx_theme,recommonmark,setuptools,sphinx,tensorrt,torch,torchvision |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
backend_config = dict(type='sdk') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./classification_dynamic.py', '../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk', has_mask=True) | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
3 changes: 3 additions & 0 deletions
3
configs/mmedit/super-resolution/super-resolution_sdk_dynamic.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./text-detection_dynamic.py', '../../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
8 changes: 8 additions & 0 deletions
8
configs/mmocr/text-recognition/text-recognition_sdk_dynamic.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./text-recognition_dynamic.py', '../../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./segmentation_dynamic.py', '../_base_/backends/sdk.py'] | ||
|
||
codebase_config = dict(model_type='sdk') | ||
|
||
backend_config = dict(pipeline=[ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
cmake_minimum_required(VERSION 3.14) | ||
project(mmdeploy_python) | ||
|
||
if (NOT TARGET pybind11) | ||
add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) | ||
endif () | ||
|
||
set(MMDEPLOY_PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp) | ||
|
||
macro(mmdeploy_python_add_module name) | ||
if (TARGET mmdeploy_${name}) | ||
list(APPEND MMDEPLOY_PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${name}.cpp) | ||
endif () | ||
endmacro() | ||
|
||
mmdeploy_python_add_module(classifier) | ||
mmdeploy_python_add_module(detector) | ||
mmdeploy_python_add_module(segmentor) | ||
mmdeploy_python_add_module(text_detector) | ||
mmdeploy_python_add_module(text_recognizer) | ||
mmdeploy_python_add_module(restorer) | ||
|
||
pybind11_add_module(${PROJECT_NAME} ${MMDEPLOY_PYTHON_SRCS}) | ||
|
||
target_link_libraries(${PROJECT_NAME} PRIVATE | ||
${MMDEPLOY_LIBS} | ||
-Wl,--whole-archive ${MMDEPLOY_STATIC_MODULES} -Wl,--no-whole-archive | ||
-Wl,--no-as-needed ${MMDEPLOY_DYNAMIC_MODULES} -Wl,--as-need) | ||
|
||
target_include_directories(${PROJECT_NAME} PRIVATE | ||
${CMAKE_CURRENT_SOURCE_DIR} | ||
${CMAKE_CURRENT_SOURCE_DIR}/../..) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#include "classifier.h" | ||
|
||
#include "common.h" | ||
|
||
namespace mmdeploy { | ||
|
||
class PyClassifier { | ||
public: | ||
PyClassifier(const char *model_path, const char *device_name, int device_id) { | ||
auto status = mmdeploy_classifier_create_by_path(model_path, device_name, device_id, &handle_); | ||
if (status != MM_SUCCESS) { | ||
throw std::runtime_error("failed to create classifier"); | ||
} | ||
} | ||
~PyClassifier() { | ||
mmdeploy_classifier_destroy(handle_); | ||
handle_ = {}; | ||
} | ||
|
||
// std::vector<py::array_t<float>> | ||
std::vector<std::vector<std::tuple<int, float>>> Apply(const std::vector<PyImage> &imgs) { | ||
std::vector<mm_mat_t> mats; | ||
mats.reserve(imgs.size()); | ||
for (const auto &img : imgs) { | ||
auto mat = GetMat(img); | ||
mats.push_back(mat); | ||
} | ||
mm_class_t *results{}; | ||
int *result_count{}; | ||
auto status = | ||
mmdeploy_classifier_apply(handle_, mats.data(), (int)mats.size(), &results, &result_count); | ||
if (status != MM_SUCCESS) { | ||
throw std::runtime_error("failed to apply classifier, code: " + std::to_string(status)); | ||
} | ||
auto output = std::vector<std::vector<std::tuple<int, float>>>{}; | ||
output.reserve(mats.size()); | ||
auto result_ptr = results; | ||
for (int i = 0; i < mats.size(); ++i) { | ||
std::vector<std::tuple<int, float>> label_score; | ||
for (int j = 0; j < result_count[i]; ++j) { | ||
label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); | ||
} | ||
output.push_back(std::move(label_score)); | ||
result_ptr += result_count[i]; | ||
} | ||
mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); | ||
return output; | ||
} | ||
|
||
private: | ||
mm_handle_t handle_{}; | ||
}; | ||
|
||
static void register_python_classifier(py::module &m) { | ||
py::class_<PyClassifier>(m, "Classifier") | ||
.def(py::init([](const char *model_path, const char *device_name, int device_id) { | ||
return std::make_unique<PyClassifier>(model_path, device_name, device_id); | ||
})) | ||
.def("__call__", &PyClassifier::Apply); | ||
} | ||
|
||
class PythonClassifierRegisterer { | ||
public: | ||
PythonClassifierRegisterer() { | ||
gPythonBindings().emplace("classifier", register_python_classifier); | ||
} | ||
}; | ||
|
||
static PythonClassifierRegisterer python_classifier_registerer; | ||
|
||
} // namespace mmdeploy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#include "apis/python/common.h" | ||
|
||
namespace mmdeploy { | ||
|
||
std::map<std::string, void (*)(py::module &)> &gPythonBindings() { | ||
static std::map<std::string, void (*)(py::module &)> v; | ||
return v; | ||
} | ||
|
||
mm_mat_t GetMat(const PyImage &img) { | ||
auto info = img.request(); | ||
if (info.ndim != 3) { | ||
fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); | ||
throw std::runtime_error("continuous uint8 HWC array expected"); | ||
} | ||
auto channels = (int)info.shape[2]; | ||
mm_mat_t mat{}; | ||
if (channels == 1) { | ||
mat.format = MM_GRAYSCALE; | ||
} else if (channels == 3) { | ||
mat.format = MM_BGR; | ||
} else { | ||
throw std::runtime_error("images of 1 or 3 channels are supported"); | ||
} | ||
mat.height = (int)info.shape[0]; | ||
mat.width = (int)info.shape[1]; | ||
mat.channel = channels; | ||
mat.type = MM_INT8; | ||
mat.data = (uint8_t *)info.ptr; | ||
return mat; | ||
} | ||
|
||
} // namespace mmdeploy | ||
|
||
PYBIND11_MODULE(mmdeploy_python, m) { | ||
for (const auto &[_, f] : mmdeploy::gPythonBindings()) { | ||
f(m); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#ifndef MMDEPLOY_CSRC_APIS_PYTHON_COMMON_H_ | ||
#define MMDEPLOY_CSRC_APIS_PYTHON_COMMON_H_ | ||
|
||
#include <stdexcept> | ||
|
||
#include "apis/c/common.h" | ||
#include "pybind11/numpy.h" | ||
#include "pybind11/pybind11.h" | ||
#include "pybind11/stl.h" | ||
|
||
namespace py = pybind11; | ||
|
||
using PyImage = py::array_t<uint8_t, py::array::c_style | py::array::forcecast>; | ||
|
||
namespace mmdeploy { | ||
|
||
std::map<std::string, void (*)(py::module &)> &gPythonBindings(); | ||
|
||
mm_mat_t GetMat(const PyImage &img); | ||
|
||
} // namespace mmdeploy | ||
|
||
#endif // MMDEPLOY_CSRC_APIS_PYTHON_COMMON_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
#include "detector.h" | ||
|
||
#include "common.h" | ||
|
||
namespace mmdeploy { | ||
|
||
class PyDetector { | ||
public: | ||
PyDetector(const char *model_path, const char *device_name, int device_id) { | ||
auto status = mmdeploy_detector_create_by_path(model_path, device_name, device_id, &handle_); | ||
if (status != MM_SUCCESS) { | ||
throw std::runtime_error("failed to create detector"); | ||
} | ||
} | ||
py::list Apply(const std::vector<PyImage> &imgs) { | ||
std::vector<mm_mat_t> mats; | ||
mats.reserve(imgs.size()); | ||
for (const auto &img : imgs) { | ||
auto mat = GetMat(img); | ||
mats.push_back(mat); | ||
} | ||
mm_detect_t *detection{}; | ||
int *result_count{}; | ||
auto status = | ||
mmdeploy_detector_apply(handle_, mats.data(), (int)mats.size(), &detection, &result_count); | ||
if (status != MM_SUCCESS) { | ||
throw std::runtime_error("failed to apply detector, code: " + std::to_string(status)); | ||
} | ||
auto output = py::list{}; | ||
auto result = detection; | ||
for (int i = 0; i < mats.size(); ++i) { | ||
auto bboxes = py::array_t<float>({result_count[i], 5}); | ||
auto labels = py::array_t<int>(result_count[i]); | ||
auto masks = std::vector<py::array_t<uint8_t>>{}; | ||
masks.reserve(result_count[i]); | ||
for (int j = 0; j < result_count[i]; ++j, ++result) { | ||
auto bbox = bboxes.mutable_data(j); | ||
bbox[0] = result->bbox.left; | ||
bbox[1] = result->bbox.top; | ||
bbox[2] = result->bbox.right; | ||
bbox[3] = result->bbox.bottom; | ||
bbox[4] = result->score; | ||
labels.mutable_at(j) = result->label_id; | ||
if (result->mask) { | ||
py::array_t<uint8_t> mask({result->mask->height, result->mask->width}); | ||
memcpy(mask.mutable_data(), result->mask->data, mask.nbytes()); | ||
masks.push_back(std::move(mask)); | ||
} else { | ||
masks.emplace_back(); | ||
} | ||
} | ||
output.append(py::make_tuple(std::move(bboxes), std::move(labels), std::move(masks))); | ||
} | ||
mmdeploy_detector_release_result(detection, result_count, (int)mats.size()); | ||
return output; | ||
} | ||
~PyDetector() { | ||
mmdeploy_detector_destroy(handle_); | ||
handle_ = {}; | ||
} | ||
|
||
private: | ||
mm_handle_t handle_{}; | ||
}; | ||
|
||
static void register_python_detector(py::module &m) { | ||
py::class_<PyDetector>(m, "Detector") | ||
.def(py::init([](const char *model_path, const char *device_name, int device_id) { | ||
return std::make_unique<PyDetector>(model_path, device_name, device_id); | ||
})) | ||
.def("__call__", &PyDetector::Apply); | ||
} | ||
|
||
class PythonDetectorRegisterer { | ||
public: | ||
PythonDetectorRegisterer() { gPythonBindings().emplace("detector", register_python_detector); } | ||
}; | ||
|
||
static PythonDetectorRegisterer python_detector_registerer; | ||
|
||
} // namespace mmdeploy |
Oops, something went wrong.