diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h index f895af4150..4b27fbab8a 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h @@ -36,7 +36,7 @@ typedef struct mmdeploy_pose_tracker_param_t { int32_t pose_max_num_bboxes; // threshold for visible key-points, default = 0.5 float pose_kpt_thr; - // min number of key-points for valid poses, default = -1 + // min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1 int32_t pose_min_keypoints; // scale for expanding key-points to bbox, default = 1.25 float pose_bbox_scale; diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp index 313a37dc04..ba5823f558 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp @@ -119,10 +119,17 @@ int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation results_ptr->height = segmentor_output.height; results_ptr->width = segmentor_output.width; results_ptr->classes = segmentor_output.classes; - auto mask_size = results_ptr->height * results_ptr->width; auto& mask = segmentor_output.mask; - results_ptr->mask = mask.data(); - buffers[i] = mask.buffer(); + auto& score = segmentor_output.score; + results_ptr->mask = nullptr; + results_ptr->score = nullptr; + if (mask.shape().size()) { + results_ptr->mask = mask.data(); + buffers[i] = mask.buffer(); + } else { + results_ptr->score = score.data(); + buffers[i] = score.buffer(); + } } *results = results_data; diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h index 7ae77a03f1..65bcfd03f3 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h @@ -17,11 +17,14 @@ extern "C" { #endif typedef struct mmdeploy_segmentation_t { - int height; ///< height of \p mask that equals to the input image's height - int width; ///< width of \p mask that equals to the input image's width - int classes; ///< the number of labels in \p mask - int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates - ///< the label id of pixel at (i, j) + int height; ///< height of \p mask that equals to the input image's height + int width; ///< width of \p mask that equals to the input image's width + int classes; ///< the number of labels in \p mask + int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates + ///< the label id of pixel at (i, j), this field might be null + float* score; ///< segmentation score map of the input image in CHW format, in which + ///< score[height * width * k + i * width + j] indicates the score + ///< of class k at pixel (i, j), this field might be null } mmdeploy_segmentation_t; typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t; diff --git a/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs b/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs index 6470a2d8a8..c3b75ca603 100644 --- a/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs +++ b/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs @@ -10,6 +10,7 @@ internal unsafe struct CSegment public int Width; public int Classes; public int* Mask; + public float* Score; } #pragma warning restore 0649 @@ -34,10 +35,16 @@ public struct SegmentorOutput public int Classes; /// - /// Mask data. + /// Mask data, mask[i * width + j] indicates the label id of pixel at (i, j). /// public int[] Mask; + /// + /// Score data, score[height * width * k + i * width + j] indicates the score + /// of class k at pixel (i, j). + /// + public float[] Score; + /// /// Initializes a new instance of the struct. /// @@ -45,13 +52,31 @@ public struct SegmentorOutput /// width. /// classes. /// mask. - public SegmentorOutput(int height, int width, int classes, int[] mask) + /// score. + public SegmentorOutput(int height, int width, int classes, int[] mask, float[] score) { Height = height; Width = width; Classes = classes; - Mask = new int[Height * Width]; - Array.Copy(mask, this.Mask, mask.Length); + if (mask.Length > 0) + { + Mask = new int[Height * Width]; + Array.Copy(mask, this.Mask, mask.Length); + } + else + { + Mask = new int[] { }; + } + + if (score.Length > 0) + { + Score = new float[Height * Width * Classes]; + Array.Copy(score, this.Score, score.Length); + } + else + { + Score = new float[] { }; + } } internal unsafe SegmentorOutput(CSegment* result) @@ -59,11 +84,34 @@ internal unsafe SegmentorOutput(CSegment* result) Height = result->Height; Width = result->Width; Classes = result->Classes; - Mask = new int[Height * Width]; - int nbytes = Height * Width * sizeof(int); - fixed (int* data = this.Mask) + if (result->Mask != null) + { + Mask = new int[Height * Width]; + + int nbytes = Height * Width * sizeof(int); + fixed (int* data = this.Mask) + { + Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes); + } + } + else + { + Mask = new int[] { }; + } + + if (result->Score != null) + { + Score = new float[Height * Width * Classes]; + + int nbytes = Height * Width * Classes * sizeof(float); + fixed (float* data = this.Score) + { + Buffer.MemoryCopy(result->Score, data, nbytes, nbytes); + } + } + else { - Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes); + Score = new float[] { }; } } } diff --git a/csrc/mmdeploy/apis/cxx/CMakeLists.txt b/csrc/mmdeploy/apis/cxx/CMakeLists.txt index 19d56344a6..6a5160d9c6 100644 --- a/csrc/mmdeploy/apis/cxx/CMakeLists.txt +++ b/csrc/mmdeploy/apis/cxx/CMakeLists.txt @@ -24,4 +24,5 @@ install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/common.hpp install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ DESTINATION example/cpp FILES_MATCHING PATTERN "*.cxx" + PATTERN "*.h" ) diff --git a/csrc/mmdeploy/apis/python/segmentor.cpp b/csrc/mmdeploy/apis/python/segmentor.cpp index 1fdf719fc8..940972ab61 100644 --- a/csrc/mmdeploy/apis/python/segmentor.cpp +++ b/csrc/mmdeploy/apis/python/segmentor.cpp @@ -37,12 +37,22 @@ class PySegmentor { std::vector rets(mats.size()); for (size_t i = 0; i < mats.size(); ++i) { - rets[i] = { - {segm[i].height, segm[i].width}, // shape - segm[i].mask, // data - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; + if (segm[i].mask != nullptr) { + rets[i] = { + {segm[i].height, segm[i].width}, // shape + segm[i].mask, // mask + py::capsule(new Sptr(holder), // handle + [](void* p) { delete reinterpret_cast(p); }) // + }; + } + if (segm[i].score != nullptr) { + rets[i] = { + {segm[i].classes, segm[i].height, segm[i].width}, // shape + segm[i].score, // score + py::capsule(new Sptr(holder), // handle + [](void* p) { delete reinterpret_cast(p); }) // + }; + } } return rets; } diff --git a/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt index 7aa25aebe2..2ea41f7271 100644 --- a/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt +++ b/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt @@ -4,8 +4,7 @@ project(mmdeploy_mmaction) file(GLOB SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") -add_subdirectory(cpu) -add_subdirectory(cuda) + target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_operation mmdeploy_transform diff --git a/csrc/mmdeploy/codebase/mmaction/cpu/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/cpu/CMakeLists.txt deleted file mode 100644 index c815528397..0000000000 --- a/csrc/mmdeploy/codebase/mmaction/cpu/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -project(mmdeploy_mmaction_cpu_impl CXX) - -if ("cpu" IN_LIST MMDEPLOY_TARGET_DEVICES) - add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp) - set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1) - if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC)) - target_compile_options(${PROJECT_NAME} PRIVATE $<$:-fvisibility=hidden>) - endif () - target_link_libraries(${PROJECT_NAME} PRIVATE - mmdeploy::core) - target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME}) - mmdeploy_export(${PROJECT_NAME}) -endif () diff --git a/csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp b/csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp deleted file mode 100644 index 0a6900cbf3..0000000000 --- a/csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "mmdeploy/codebase/mmaction/format_shape.h" -#include "mmdeploy/core/utils/device_utils.h" - -using namespace std; - -namespace mmdeploy::mmaction::cpu { - -class FormatShapeImpl : public FormatShapeOp { - public: - explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {} - - protected: - Device host_{0, 0}; - - const Device& GetDevice() { return host_; } - - Result Transpose(Tensor& src, const TensorShape& src_dims, - const std::vector& permutation) { - Tensor dst(src.desc()); - TensorShape shape(src.shape().size()); - for (int i = 0; i < shape.size(); i++) { - shape[i] = src.shape(permutation[i]); - } - dst.Reshape(shape); - int ndim = shape.size(); - std::vector dst_strides(ndim); - std::vector src_strides(ndim); - dst_strides[ndim - 1] = src_strides[ndim - 1] = 1; - for (int i = ndim - 2; i >= 0; i--) { - dst_strides[i] = dst_strides[i + 1] * shape[i + 1]; - src_strides[i] = src_strides[i + 1] * src_dims[i + 1]; - } - std::vector tmp(ndim); - for (int i = 0; i < ndim; i++) { - tmp[i] = src_strides[permutation[i]]; - } - src_strides.swap(tmp); - std::vector coord(ndim, 0); - auto dst_data = dst.data(); - auto src_data = src.data(); - - int i; - do { - dst_data[0] = src_data[0]; - for (i = ndim - 1; i >= 0; i--) { - if (++coord[i] == shape[i]) { - coord[i] = 0; - dst_data -= (shape[i] - 1) * dst_strides[i]; - src_data -= (shape[i] - 1) * src_strides[i]; - } else { - dst_data += dst_strides[i]; - src_data += src_strides[i]; - break; - } - } - } while (i >= 0); - return dst; - } -}; - -MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) { - return std::make_unique(std::move(input_format)); -}); - -} // namespace mmdeploy::mmaction::cpu diff --git a/csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt deleted file mode 100644 index 9502a33960..0000000000 --- a/csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -if (NOT "cuda" IN_LIST MMDEPLOY_TARGET_DEVICES) - return() -endif () - -project(mmdeploy_mmaction_cuda_impl CXX) - -add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp transpose.cu) -set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1) -if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC)) - target_compile_options(${PROJECT_NAME} PRIVATE $<$:-fvisibility=hidden>) -endif () -target_include_directories(${PROJECT_NAME} PRIVATE - ${CUDA_INCLUDE_DIRS}) -target_link_libraries(${PROJECT_NAME} PRIVATE - mmdeploy::core) -target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME}) -mmdeploy_export(${PROJECT_NAME}) diff --git a/csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp b/csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp deleted file mode 100644 index 391b5afc6c..0000000000 --- a/csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "cuda_runtime.h" -#include "mmdeploy/codebase/mmaction/format_shape.h" -#include "mmdeploy/core/utils/device_utils.h" - -using namespace std; - -namespace mmdeploy::mmaction::cuda { - -template -void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim, - int total, cudaStream_t stream); - -class FormatShapeImpl : public FormatShapeOp { - public: - explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {} - - protected: - const Device& GetDevice() { return device(); } - - Result Transpose(Tensor& src, const TensorShape& src_dims, - const std::vector& permutation) { - Tensor dst(src.desc()); - TensorShape shape(src.shape().size()); - for (int i = 0; i < shape.size(); i++) { - shape[i] = src.shape(permutation[i]); - } - dst.Reshape(shape); - - auto ndim = src_dims.size(); - std::vector dst_dims(ndim); - for (int i = 0; i < ndim; i++) { - dst_dims[i] = src_dims[permutation[i]]; - } - - std::vector src_strides(ndim); - std::vector dst_strides(ndim); - std::vector buffer(ndim); - buffer.back() = 1; - dst_strides.back() = 1; - for (int i = ndim - 1; i > 0; i--) { - buffer[i - 1] = buffer[i] * src_dims[i]; - dst_strides[i - 1] = dst_strides[i] * dst_dims[i]; - } - for (int i = 0; i < ndim; ++i) { - src_strides[i] = buffer[permutation[i]]; - } - - Buffer _src_strides(Device("cuda"), sizeof(int) * ndim); - Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim); - OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides)); - OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides)); - - ::mmdeploy::mmaction::cuda::Transpose(src.data(), GetNative(_src_strides), - dst.data(), GetNative(_dst_strides), ndim, - src.size(), (cudaStream_t)stream().GetNative()); - return dst; - } -}; - -MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) { - return std::make_unique(std::move(input_format)); -}); - -} // namespace mmdeploy::mmaction::cuda diff --git a/csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu b/csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu deleted file mode 100644 index bef1d447e8..0000000000 --- a/csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include -#include - -namespace mmdeploy { -namespace mmaction { -namespace cuda { - -template -__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, - int ndim, int total) { - int u = blockIdx.x * blockDim.x + threadIdx.x; - if (u >= total) { - return; - } - - int remaining = u; - int v = 0; - for (int i = 0; i < ndim; i++) { - int p = remaining / dst_strides[i]; - remaining -= p * dst_strides[i]; - v += p * src_strides[i]; - } - dst[u] = src[v]; -} - -template -void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim, - int total, cudaStream_t stream) { - int thread_num = 256; - int block_num = (total + thread_num - 1) / thread_num; - transpose - <<>>(src, src_strides, dst, dst_strides, ndim, total); -} - -template void Transpose(const float* src, const int* src_strides, float* dst, - const int* dst_strides, int ndim, int total, cudaStream_t stream); - -} // namespace cuda -} // namespace mmaction -} // namespace mmdeploy diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp index 81d9ac478f..7d8c6ac5c6 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp @@ -10,44 +10,23 @@ using namespace std; namespace mmdeploy::mmaction { FormatShape::FormatShape(const Value& args) { - auto input_format = args.value("input_format", std::string("")); - if (input_format != "NCHW" && input_format != "NCTHW") { + input_format_ = args.value("input_format", std::string("")); + if (input_format_ != "NCHW" && input_format_ != "NCTHW") { MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'"); throw_exception(eInvalidArgument); } - format_ = operation::Managed::Create(input_format); + permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); } -Result FormatShapeOp::apply(const std::vector& images, Tensor& output, int clip_len, - int num_clips) { - Tensor inputs; - OUTCOME_TRY(MergeInputs(images, inputs)); - if (GetDevice().is_host()) { - OUTCOME_TRY(stream().Wait()); - } - - // Tensor dst; - if (input_format_ == "NCHW") { - OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips)); - } - if (input_format_ == "NCTHW") { - OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips)); - } - - TensorShape expand_dim = output.shape(); - expand_dim.insert(expand_dim.begin(), 1); - output.Reshape(expand_dim); - - return success(); -} - -Result FormatShapeOp::MergeInputs(const std::vector& images, Tensor& inputs) { +Result FormatShape::MergeInputs(const std::vector& images, Tensor& inputs) { auto N = static_cast(images.size()); auto H = images[0].shape(1); auto W = images[0].shape(2); auto C = images[0].shape(3); + auto& device = operation::gContext().device(); + auto& stream = operation::gContext().stream(); - TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}}; + TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}}; inputs = Tensor(desc); auto offset = 0UL; auto n_item = H * W * C; @@ -55,21 +34,39 @@ Result FormatShapeOp::MergeInputs(const std::vector& images, Tenso for (int i = 0; i < N; i++) { auto src_buffer = images[i].buffer(); auto dst_buffer = inputs.buffer(); - OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset)); + OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset)); offset += copy_size; } return success(); } -Result FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) { - auto N = src.shape(0); - auto H = src.shape(1); - auto W = src.shape(2); - auto C = src.shape(3); - return Transpose(src, {N, H, W, C}, {0, 3, 1, 2}); +Result FormatShape::Format(const std::vector& images, Tensor& output, int clip_len, + int num_clips) { + Tensor inputs; + OUTCOME_TRY(MergeInputs(images, inputs)); + + // Tensor dst; + if (input_format_ == "NCHW") { + OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output)); + } + if (input_format_ == "NCTHW") { + OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output)); + } + + TensorShape expand_dim = output.shape(); + expand_dim.insert(expand_dim.begin(), 1); + output.Reshape(expand_dim); + + return success(); +} + +Result FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { + const vector axes = {0, 3, 1, 2}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); } -Result FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) { +Result FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { auto N = src.shape(0); auto H = src.shape(1); auto W = src.shape(2); @@ -80,8 +77,9 @@ Result FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_cli } int M = N / L; src.Reshape({M, L, H, W, C}); - - return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3}); + const vector axes = {0, 4, 1, 2, 3}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); } Result FormatShape::Apply(Value& data) { @@ -119,7 +117,7 @@ Result FormatShape::Apply(Value& data) { Tensor dst; data = Value{}; - OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips)); + OUTCOME_TRY(Format(images, dst, clip_len, num_clips)); data["img"] = std::move(dst); return success(); @@ -127,6 +125,4 @@ Result FormatShape::Apply(Value& data) { MMDEPLOY_REGISTER_TRANSFORM(FormatShape); -MMDEPLOY_DEFINE_REGISTRY(FormatShapeOp); - } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.h b/csrc/mmdeploy/codebase/mmaction/format_shape.h index 13b6648784..97e4f99356 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.h +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.h @@ -1,51 +1,39 @@ // Copyright (c) OpenMMLab. All rights reserved. -#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_ -#define MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_ +#ifndef MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_ +#define MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_ #include +#include #include #include "mmdeploy/core/tensor.h" #include "mmdeploy/operation/managed.h" +#include "mmdeploy/operation/vision.h" #include "mmdeploy/preprocess/transform/transform.h" namespace mmdeploy::mmaction { -class FormatShapeOp : public operation::Operation { +class FormatShape : public Transform { public: - explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){}; - - Result apply(const std::vector& inputs, Tensor& output, int clip_len, - int num_clips); + explicit FormatShape(const Value& args); - virtual const Device& GetDevice() = 0; + Result Apply(Value& data) override; - virtual Result Transpose(Tensor& src, const TensorShape& src_dims, - const std::vector& permutation) = 0; + Result Format(const std::vector& images, Tensor& output, int clip_len, + int num_clips); - Result FormatNCHW(Tensor& src, int clip_len, int num_clips); + Result FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); - Result FormatNCTHW(Tensor& src, int clip_len, int num_clips); + Result FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); Result MergeInputs(const std::vector& images, Tensor& inputs); - protected: - std::string input_format_; -}; - -class FormatShape : public Transform { - public: - explicit FormatShape(const Value& args); - - Result Apply(Value& data) override; - private: - operation::Managed format_; + std::string input_format_; + operation::Managed permute_; }; -MMDEPLOY_DECLARE_REGISTRY(FormatShapeOp, std::unique_ptr(std::string input_format)); - } // namespace mmdeploy::mmaction #endif diff --git a/csrc/mmdeploy/codebase/mmaction/mmaction.h b/csrc/mmdeploy/codebase/mmaction/mmaction.h index 238b780962..ef097e6f20 100644 --- a/csrc/mmdeploy/codebase/mmaction/mmaction.h +++ b/csrc/mmdeploy/codebase/mmaction/mmaction.h @@ -1,7 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. -#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_ -#define MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_ +#ifndef MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_ +#define MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_ #include "mmdeploy/codebase/common.h" #include "mmdeploy/core/device.h" diff --git a/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h b/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h index 3643c99387..676e87157d 100644 --- a/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h +++ b/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h @@ -22,7 +22,7 @@ using Points = vector; using Score = float; using Scores = vector; -#define POSE_TRACKER_DEBUG(...) MMDEPLOY_INFO(__VA_ARGS__) +#define POSE_TRACKER_DEBUG(...) MMDEPLOY_DEBUG(__VA_ARGS__) // opencv3 can't construct cv::Mat from std::array template diff --git a/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt b/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt index fc05b1e20d..aac2376346 100644 --- a/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt +++ b/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt @@ -4,7 +4,9 @@ project(mmdeploy_mmseg) file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") -target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils) +target_link_libraries(${PROJECT_NAME} PRIVATE + mmdeploy_opencv_utils + mmdeploy_operation) add_library(mmdeploy::mmseg ALIAS ${PROJECT_NAME}) set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} segmentor CACHE INTERNAL "") diff --git a/csrc/mmdeploy/codebase/mmseg/mmseg.h b/csrc/mmdeploy/codebase/mmseg/mmseg.h index 8a9c0f6d78..8f55fadce1 100644 --- a/csrc/mmdeploy/codebase/mmseg/mmseg.h +++ b/csrc/mmdeploy/codebase/mmseg/mmseg.h @@ -12,10 +12,11 @@ namespace mmdeploy::mmseg { struct SegmentorOutput { Tensor mask; + Tensor score; int height; int width; int classes; - MMDEPLOY_ARCHIVE_MEMBERS(mask, height, width, classes); + MMDEPLOY_ARCHIVE_MEMBERS(mask, score, height, width, classes); }; MMDEPLOY_DECLARE_CODEBASE(MMSegmentation, mmseg); diff --git a/csrc/mmdeploy/codebase/mmseg/segment.cpp b/csrc/mmdeploy/codebase/mmseg/segment.cpp index b1128886c2..56811a4fad 100644 --- a/csrc/mmdeploy/codebase/mmseg/segment.cpp +++ b/csrc/mmdeploy/codebase/mmseg/segment.cpp @@ -5,6 +5,8 @@ #include "mmdeploy/core/tensor.h" #include "mmdeploy/core/utils/device_utils.h" #include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy/operation/managed.h" +#include "mmdeploy/operation/vision.h" #include "mmdeploy/preprocess/transform/transform.h" #include "opencv_utils.h" @@ -18,7 +20,10 @@ class ResizeMask : public MMSegmentation { explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) { try { classes_ = cfg["params"]["num_classes"].get(); + with_argmax_ = cfg["params"].value("with_argmax", true); little_endian_ = IsLittleEndian(); + ::mmdeploy::operation::Context ctx(Device("cpu"), stream_); + permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); } catch (const std::exception &e) { MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg); throw_exception(eInvalidArgument); @@ -31,40 +36,71 @@ class ResizeMask : public MMSegmentation { auto mask = inference_result["output"].get(); MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(), mask.data_type()); - if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) { + if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) { MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape()); return Status(eNotSupported); } + if ((mask.shape(1) != 1) && with_argmax_) { + MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`", + mask.shape()); + return Status(eNotSupported); + } + if ((mask.data_type() != DataType::kFLOAT) && !with_argmax_) { + MMDEPLOY_ERROR("probability feat map only support float32 output"); + return Status(eNotSupported); + } + auto channel = (int)mask.shape(1); auto height = (int)mask.shape(2); auto width = (int)mask.shape(3); auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get(); auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get(); Device host{"cpu"}; OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_)); - OUTCOME_TRY(stream_.Wait()); + if (!with_argmax_) { + // (C, H, W) -> (H, W, C) + ::mmdeploy::operation::Context ctx(host, stream_); + std::vector axes = {0, 2, 3, 1}; + OUTCOME_TRY(permute_.Apply(host_tensor, host_tensor, axes)); + } - OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type())); + OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type(), channel)); cv::Mat mask_mat(height, width, cv_type, host_tensor.data()); - if (mask_mat.channels() > 1) { - cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1); - } - if (mask_mat.type() != CV_32S) { - mask_mat.convertTo(mask_mat, CV_32S); - } + cv::Mat resized_mask; + cv::Mat resized_score; - cv::Mat resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest"); + Tensor tensor_mask{}; + Tensor tensor_score{}; + + if (with_argmax_) { + // mask + if (mask_mat.channels() > 1) { + cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1); + } + if (mask_mat.type() != CV_32S) { + mask_mat.convertTo(mask_mat, CV_32S); + } + resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest"); + tensor_mask = cpu::CVMat2Tensor(resized_mask); + } else { + // score + resized_score = cpu::Resize(mask_mat, input_height, input_width, "bilinear"); + tensor_score = cpu::CVMat2Tensor(resized_score); + std::vector axes = {0, 3, 1, 2}; + ::mmdeploy::operation::Context ctx(host, stream_); + OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes)); + } - SegmentorOutput output{cpu::CVMat2Tensor(resized_mask), input_height, input_width, classes_}; + SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_}; return to_value(output); } private: - static Result GetCvType(DataType type) { + static Result GetCvType(DataType type, int channel) { switch (type) { case DataType::kFLOAT: - return CV_32F; + return CV_32FC(channel); case DataType::kINT64: return CV_32SC2; case DataType::kINT32: @@ -84,7 +120,9 @@ class ResizeMask : public MMSegmentation { } protected: + ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute> permute_; int classes_{}; + bool with_argmax_{true}; bool little_endian_; }; diff --git a/csrc/mmdeploy/operation/cpu/CMakeLists.txt b/csrc/mmdeploy/operation/cpu/CMakeLists.txt index fa6a1de56c..5607123deb 100644 --- a/csrc/mmdeploy/operation/cpu/CMakeLists.txt +++ b/csrc/mmdeploy/operation/cpu/CMakeLists.txt @@ -11,7 +11,8 @@ set(SRCS resize.cpp crop.cpp flip.cpp warp_affine.cpp - crop_resize_pad.cpp) + crop_resize_pad.cpp + permute.cpp) mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") diff --git a/csrc/mmdeploy/operation/cpu/permute.cpp b/csrc/mmdeploy/operation/cpu/permute.cpp new file mode 100644 index 0000000000..44c98fe24d --- /dev/null +++ b/csrc/mmdeploy/operation/cpu/permute.cpp @@ -0,0 +1,91 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "mmdeploy/operation/vision.h" +#include "mmdeploy/utils/opencv/opencv_utils.h" + +namespace mmdeploy::operation::cpu { + +class PermuteImpl : public Permute { + public: + explicit PermuteImpl() {} + + Result apply(const Tensor& src, Tensor& dst, const std::vector& axes) override { + int ndim = src.shape().size(); + if (ndim != axes.size()) { + MMDEPLOY_ERROR("The size of axes should be equal to src, {} vs {}", axes.size(), ndim); + return Status(eInvalidArgument); + } + std::vector axes_vis(ndim, 0); + for (const auto& x : axes) { + if (x < 0 || x >= ndim || axes_vis[x]) { + MMDEPLOY_ERROR("Invalid axes"); + return Status(eInvalidArgument); + } + axes_vis[x] = 1; + } + + Tensor dst_tensor(src.desc()); + auto src_dims = src.shape(); + TensorShape dst_dims(ndim); + for (int i = 0; i < src_dims.size(); i++) { + dst_dims[i] = src_dims[axes[i]]; + } + dst_tensor.Reshape(dst_dims); + + std::vector dst_strides(ndim); + std::vector src_strides(ndim); + dst_strides[ndim - 1] = src_strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; i--) { + dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1]; + src_strides[i] = src_strides[i + 1] * src_dims[i + 1]; + } + + std::vector tmp(ndim); + for (int i = 0; i < ndim; i++) { + tmp[i] = src_strides[axes[i]]; + } + src_strides.swap(tmp); + + if (src.data_type() == DataType::kINT8) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else if (src.data_type() == DataType::kFLOAT) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else { + MMDEPLOY_ERROR("unsupported data type {}", src.data_type()); + return Status(eNotSupported); + } + dst = std::move(dst_tensor); + return success(); + } + + template + Result PermuteDispatch(const Tensor& src, Tensor& dst, const std::vector& src_strides, + const std::vector& dst_strides) { + auto shape = dst.shape(); + int ndim = src.shape().size(); + std::vector coord(ndim, 0); + auto dst_data = dst.data(); + auto src_data = src.data(); + + int i; + do { + dst_data[0] = src_data[0]; + for (i = ndim - 1; i >= 0; i--) { + if (++coord[i] == shape[i]) { + coord[i] = 0; + dst_data -= (shape[i] - 1) * dst_strides[i]; + src_data -= (shape[i] - 1) * src_strides[i]; + } else { + dst_data += dst_strides[i]; + src_data += src_strides[i]; + break; + } + } + } while (i >= 0); + return success(); + } +}; + +MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cpu, 0), []() { return std::make_unique(); }); + +} // namespace mmdeploy::operation::cpu diff --git a/csrc/mmdeploy/operation/cuda/CMakeLists.txt b/csrc/mmdeploy/operation/cuda/CMakeLists.txt index 8322b9d2ad..551f89977b 100644 --- a/csrc/mmdeploy/operation/cuda/CMakeLists.txt +++ b/csrc/mmdeploy/operation/cuda/CMakeLists.txt @@ -19,7 +19,9 @@ set(SRCS resize.cpp crop.cu flip.cpp warp_affine.cpp - crop_resize_pad.cpp) + crop_resize_pad.cpp + permute.cpp + permute.cu) mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") diff --git a/csrc/mmdeploy/operation/cuda/permute.cpp b/csrc/mmdeploy/operation/cuda/permute.cpp new file mode 100644 index 0000000000..c5c87da881 --- /dev/null +++ b/csrc/mmdeploy/operation/cuda/permute.cpp @@ -0,0 +1,91 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "mmdeploy/operation/cuda/permute.h" + +#include + +#include "mmdeploy/operation/vision.h" + +namespace mmdeploy::operation::cuda { + +namespace impl { +template +void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides, + int ndim, int total, cudaStream_t stream); +} + +class PermuteImpl : public Permute { + public: + explicit PermuteImpl() {} + + Result apply(const Tensor& src, Tensor& dst, const std::vector& axes) override { + int ndim = src.shape().size(); + if (ndim != axes.size()) { + MMDEPLOY_ERROR("The size of axes should be equal of src, {} vs {}", axes.size(), ndim); + return Status(eInvalidArgument); + } + if (ndim > MAX_PERMUTE_DIM) { + MMDEPLOY_ERROR("Only support ndim < {}", MAX_PERMUTE_DIM); + return Status(eInvalidArgument); + } + std::vector axes_vis(ndim, 0); + for (const auto& x : axes) { + if (x < 0 || x >= ndim || axes_vis[x]) { + MMDEPLOY_ERROR("Invalid axes"); + return Status(eInvalidArgument); + } + axes_vis[x] = 1; + } + + Tensor dst_tensor(src.desc()); + auto src_dims = src.shape(); + TensorShape dst_dims(ndim); + for (int i = 0; i < src_dims.size(); i++) { + dst_dims[i] = src_dims[axes[i]]; + } + dst_tensor.Reshape(dst_dims); + + TensorStride dst_strides; + TensorStride src_strides; + + dst_strides[ndim - 1] = src_strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; i--) { + dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1]; + src_strides[i] = src_strides[i + 1] * src_dims[i + 1]; + } + + TensorStride tmp; + for (int i = 0; i < ndim; i++) { + tmp[i] = src_strides[axes[i]]; + } + src_strides = tmp; + + if (src.data_type() == DataType::kINT8) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else if (src.data_type() == DataType::kFLOAT) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else { + MMDEPLOY_ERROR("unsupported data type {}", src.data_type()); + return Status(eNotSupported); + } + dst = std::move(dst_tensor); + return success(); + } + + template + Result PermuteDispatch(const Tensor& src, Tensor& dst, const TensorStride& src_strides, + const TensorStride& dst_strides) { + auto src_data = src.data(); + auto dst_data = dst.data(); + auto ndim = src.shape().size(); + auto total = src.size(); + impl::Permute(src_data, src_strides, dst_data, dst_strides, ndim, total, + GetNative(stream())); + return success(); + } +}; + +MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cuda, 0), + []() { return std::make_unique(); }); + +} // namespace mmdeploy::operation::cuda diff --git a/csrc/mmdeploy/operation/cuda/permute.cu b/csrc/mmdeploy/operation/cuda/permute.cu new file mode 100644 index 0000000000..7f979ed3fc --- /dev/null +++ b/csrc/mmdeploy/operation/cuda/permute.cu @@ -0,0 +1,49 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "mmdeploy/operation/cuda/permute.h" + +namespace mmdeploy { +namespace operation { +namespace cuda { +namespace impl { + +template +__global__ void permute(const T* src, const TensorStride src_strides, T* dst, + const TensorStride dst_strides, int ndim, int total) { + int u = blockIdx.x * blockDim.x + threadIdx.x; + if (u >= total) { + return; + } + + int remaining = u; + int v = 0; + for (size_t i = 0; i < ndim; i++) { + int p = remaining / dst_strides.v_[i]; + remaining -= p * dst_strides.v_[i]; + v += p * src_strides.v_[i]; + } + dst[u] = src[v]; +} + +template +void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides, + int ndim, int total, cudaStream_t stream) { + int thread_num = 256; + int block_num = (total + thread_num - 1) / thread_num; + permute<<>>(src, src_strides, dst, dst_strides, ndim, total); +} + +template void Permute(const float* src, const TensorStride& src_strides, float* dst, + const TensorStride& dst_strides, int ndim, int total, + cudaStream_t stream); + +template void Permute(const uint8_t* src, const TensorStride& src_strides, uint8_t* dst, + const TensorStride& dst_strides, int ndim, int total, + cudaStream_t stream); + +} // namespace impl +} // namespace cuda +} // namespace operation +} // namespace mmdeploy diff --git a/csrc/mmdeploy/operation/cuda/permute.h b/csrc/mmdeploy/operation/cuda/permute.h new file mode 100644 index 0000000000..7bbc0a404f --- /dev/null +++ b/csrc/mmdeploy/operation/cuda/permute.h @@ -0,0 +1,24 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef MMDEPLOY_OPERATION_CUDA_PERMUTE_H_ +#define MMDEPLOY_OPERATION_CUDA_PERMUTE_H_ + +#include + +#include + +namespace mmdeploy { +namespace operation { +namespace cuda { + +const int MAX_PERMUTE_DIM = 8; + +struct TensorStride { + int v_[MAX_PERMUTE_DIM]; + int& operator[](size_t idx) { return v_[idx]; } +}; + +} // namespace cuda +} // namespace operation +} // namespace mmdeploy + +#endif // MMDEPLOY_OPERATION_CUDA_PERMUTE_H_ diff --git a/csrc/mmdeploy/operation/vision.cpp b/csrc/mmdeploy/operation/vision.cpp index 18694f06a5..0c0b13eb9a 100644 --- a/csrc/mmdeploy/operation/vision.cpp +++ b/csrc/mmdeploy/operation/vision.cpp @@ -14,5 +14,6 @@ MMDEPLOY_DEFINE_REGISTRY(Crop); MMDEPLOY_DEFINE_REGISTRY(Flip); MMDEPLOY_DEFINE_REGISTRY(WarpAffine); MMDEPLOY_DEFINE_REGISTRY(CropResizePad); +MMDEPLOY_DEFINE_REGISTRY(Permute); } // namespace mmdeploy::operation diff --git a/csrc/mmdeploy/operation/vision.h b/csrc/mmdeploy/operation/vision.h index 10e699fed4..013c3852b8 100644 --- a/csrc/mmdeploy/operation/vision.h +++ b/csrc/mmdeploy/operation/vision.h @@ -92,6 +92,11 @@ class CropResizePad : public Operation { }; MMDEPLOY_DECLARE_REGISTRY(CropResizePad, unique_ptr()); +class Permute : public Operation { + public: + virtual Result apply(const Tensor& src, Tensor& dst, const std::vector& axes) = 0; +}; +MMDEPLOY_DECLARE_REGISTRY(Permute, unique_ptr()); } // namespace mmdeploy::operation diff --git a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp index d3cd1ad87b..4b4cb3a9cc 100644 --- a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp +++ b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp @@ -126,7 +126,7 @@ cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width, const std::string& interpolation) { cv::Mat dst(dst_height, dst_width, src.type()); auto method = GetInterpolationMethod(interpolation).value(); - cv::resize(src, dst, dst.size(), method); + cv::resize(src, dst, dst.size(), 0, 0, method); return dst; } diff --git a/demo/csharp/image_segmentation/Program.cs b/demo/csharp/image_segmentation/Program.cs index 1318c001f4..b2cfa34147 100644 --- a/demo/csharp/image_segmentation/Program.cs +++ b/demo/csharp/image_segmentation/Program.cs @@ -75,21 +75,50 @@ static void Main(string[] args) unsafe { byte* data = colorMask.DataPointer; - fixed (int* _label = output[0].Mask) + if (output[0].Mask.Length > 0) { - int* label = _label; - for (int i = 0; i < output[0].Height; i++) + fixed (int* _label = output[0].Mask) { - for (int j = 0; j < output[0].Width; j++) + int* label = _label; + for (int i = 0; i < output[0].Height; i++) { - data[0] = palette[*label][0]; - data[1] = palette[*label][1]; - data[2] = palette[*label][2]; - data += 3; - label++; + for (int j = 0; j < output[0].Width; j++) + { + data[0] = palette[*label][0]; + data[1] = palette[*label][1]; + data[2] = palette[*label][2]; + data += 3; + label++; + } } } } + else + { + int pos = 0; + fixed (float* _score = output[0].Score) + { + float *score = _score; + int total = output[0].Height * output[0].Width; + for (int i = 0; i < output[0].Height; i++) + { + for (int j = 0; j < output[0].Width; j++) + { + List> scores = new List>(); + for (int k = 0; k < output[0].Classes; k++) + { + scores.Add(new Tuple(score[k * total + i * output[0].Width + j], k)); + } + scores.Sort(); + data[0] = palette[scores[^1].Item2][0]; + data[1] = palette[scores[^1].Item2][1]; + data[2] = palette[scores[^1].Item2][2]; + data += 3; + } + } + } + } + } colorMask = imgs[0] * 0.5 + colorMask * 0.5; diff --git a/demo/csrc/c/image_segmentation.cpp b/demo/csrc/c/image_segmentation.cpp index fae446b4f0..df26d1585c 100644 --- a/demo/csrc/c/image_segmentation.cpp +++ b/demo/csrc/c/image_segmentation.cpp @@ -1,6 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. #include +#include #include #include #include @@ -59,8 +60,25 @@ int main(int argc, char* argv[]) { cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3); int pos = 0; + int total = color_mask.rows * color_mask.cols; + std::vector idxs(result->classes); for (auto iter = color_mask.begin(); iter != color_mask.end(); ++iter) { - *iter = palette[result->mask[pos++]]; + // output mask + if (result->mask) { + *iter = palette[result->mask[pos++]]; + } + // output score + if (result->score) { + std::iota(idxs.begin(), idxs.end(), 0); + auto k = + std::max_element(idxs.begin(), idxs.end(), + [&](int i, int j) { + return result->score[i * total + pos] < result->score[j * total + pos]; + }) - + idxs.begin(); + *iter = palette[k]; + pos += 1; + } } img = img * 0.5 + color_mask * 0.5; diff --git a/demo/csrc/cpp/classifier.cxx b/demo/csrc/cpp/classifier.cxx index 2d8b2d1e25..a4a12eea16 100644 --- a/demo/csrc/cpp/classifier.cxx +++ b/demo/csrc/cpp/classifier.cxx @@ -1,31 +1,43 @@ #include "mmdeploy/classifier.hpp" -#include - #include "opencv2/imgcodecs/imgcodecs.hpp" +#include "utils/argparse.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "classifier_output.jpg", "Output image path"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n image_classification device_name model_path image_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; } - mmdeploy::Model model(model_path); - mmdeploy::Classifier classifier(model, mmdeploy::Device{device_name, 0}); + // construct a classifier instance + mmdeploy::Classifier classifier(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); - auto res = classifier.Apply(img); + // apply the classifier; the result is an array-like class holding references to + // `mmdeploy_classification_t`, will be released automatically on destruction + mmdeploy::Classifier::Result result = classifier.Apply(img); + + // visualize results + utils::Visualize v; + auto sess = v.get_session(img); + int count = 0; + for (const mmdeploy_classification_t& cls : result) { + sess.add_label(cls.label_id, cls.score, count++); + } - for (const auto& cls : res) { - fprintf(stderr, "label: %d, score: %.4f\n", cls.label_id, cls.score); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); } return 0; diff --git a/demo/csrc/cpp/det_pose.cpp b/demo/csrc/cpp/det_pose.cpp deleted file mode 100644 index a12dd3c0b1..0000000000 --- a/demo/csrc/cpp/det_pose.cpp +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include - -#include "mmdeploy/detector.hpp" -#include "mmdeploy/pose_detector.hpp" -#include "opencv2/imgcodecs/imgcodecs.hpp" -#include "opencv2/imgproc/imgproc.hpp" - -using std::vector; - -cv::Mat Visualize(cv::Mat frame, const std::vector& poses, int size); - -int main(int argc, char* argv[]) { - const auto device_name = argv[1]; - const auto det_model_path = argv[2]; - const auto pose_model_path = argv[3]; - const auto image_path = argv[4]; - - if (argc != 5) { - std::cerr << "usage:\n\tpose_tracker device_name det_model_path pose_model_path image_path\n"; - return -1; - } - auto img = cv::imread(image_path); - if (!img.data) { - std::cerr << "failed to load image: " << image_path << "\n"; - return -1; - } - - using namespace mmdeploy; - - Context context(Device{device_name}); // create context for holding the device handle - Detector detector(Model(det_model_path), context); // create object detector - PoseDetector pose(Model(pose_model_path), context); // create pose detector - - // apply detector - auto dets = detector.Apply(img); - - // filter detections and extract bboxes for pose model - std::vector bboxes; - for (const auto& det : dets) { - if (det.label_id == 0 && det.score > .6f) { - bboxes.push_back(det.bbox); - } - } - // apply pose detector - auto poses = pose.Apply(img, bboxes); - - // visualize - auto vis = Visualize(img, {poses.begin(), poses.end()}, 1280); - cv::imwrite("det_pose_output.jpg", vis); - - return 0; -} - -struct Skeleton { - vector> skeleton; - vector palette; - vector link_color; - vector point_color; -}; - -const Skeleton& gCocoSkeleton() { - static const Skeleton inst{ - { - {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, - {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, - {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, - }, - { - {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, - {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, - {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, - {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}, - }, - {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}, - {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}, - }; - return inst; -} - -cv::Mat Visualize(cv::Mat frame, const vector& poses, int size) { - auto& [skeleton, palette, link_color, point_color] = gCocoSkeleton(); - auto scale = (float)size / (float)std::max(frame.cols, frame.rows); - if (scale != 1) { - cv::resize(frame, frame, {}, scale, scale); - } else { - frame = frame.clone(); - } - for (const auto& pose : poses) { - vector kpts(&pose.point[0].x, &pose.point[pose.length - 1].y + 1); - vector scores(pose.score, pose.score + pose.length); - std::for_each(kpts.begin(), kpts.end(), [&](auto& x) { x *= scale; }); - constexpr auto score_thr = .5f; - vector used(kpts.size()); - for (size_t i = 0; i < skeleton.size(); ++i) { - auto [u, v] = skeleton[i]; - if (scores[u] > score_thr && scores[v] > score_thr) { - used[u] = used[v] = 1; - cv::Point p_u(kpts[u * 2], kpts[u * 2 + 1]); - cv::Point p_v(kpts[v * 2], kpts[v * 2 + 1]); - cv::line(frame, p_u, p_v, palette[link_color[i]], 1, cv::LINE_AA); - } - } - for (size_t i = 0; i < kpts.size(); i += 2) { - if (used[i / 2]) { - cv::Point p(kpts[i], kpts[i + 1]); - cv::circle(frame, p, 1, palette[point_color[i / 2]], 2, cv::LINE_AA); - } - } - } - return frame; -} diff --git a/demo/csrc/cpp/det_pose.cxx b/demo/csrc/cpp/det_pose.cxx new file mode 100644 index 0000000000..b18fbe003b --- /dev/null +++ b/demo/csrc/cpp/det_pose.cxx @@ -0,0 +1,75 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "mmdeploy/detector.hpp" +#include "mmdeploy/pose_detector.hpp" +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "utils/argparse.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(det_model, "Object detection model path"); +DEFINE_ARG_string(pose_model, "Pose estimation model path"); +DEFINE_ARG_string(image, "Input image path"); + +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "det_pose_output.jpg", "Output image path"); +DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")"); + +DEFINE_int32(det_label, 0, "Detection label use for pose estimation"); +DEFINE_double(det_thr, .5, "Detection score threshold"); +DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size"); + +DEFINE_double(pose_thr, 0, "Pose key-point threshold"); + +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } + + mmdeploy::Device device{FLAGS_device}; + // create object detector + mmdeploy::Detector detector(mmdeploy::Model(ARGS_det_model), device); + // create pose detector + mmdeploy::PoseDetector pose(mmdeploy::Model(ARGS_pose_model), device); + + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_detection_t`, will be released automatically on destruction + mmdeploy::Detector::Result dets = detector.Apply(img); + + // filter detections and extract bboxes for pose model + std::vector bboxes; + for (const mmdeploy_detection_t& det : dets) { + if (det.label_id == FLAGS_det_label && det.score > FLAGS_det_thr) { + bboxes.push_back(det.bbox); + } + } + + // apply pose detector, if no bboxes are provided, full image will be used; the result is an + // array-like class holding references to `mmdeploy_pose_detection_t`, will be released + // automatically on destruction + mmdeploy::PoseDetector::Result poses = pose.Apply(img, bboxes); + + assert(bboxes.size() == poses.size()); + + // visualize results + utils::Visualize v; + v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton)); + auto sess = v.get_session(img); + for (size_t i = 0; i < bboxes.size(); ++i) { + sess.add_bbox(bboxes[i], -1, -1); + sess.add_pose(poses[i].point, poses[i].score, poses[i].length, FLAGS_pose_thr); + } + + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } + + return 0; +} diff --git a/demo/csrc/cpp/detector.cxx b/demo/csrc/cpp/detector.cxx index 7009d2fd21..e5b9f58183 100644 --- a/demo/csrc/cpp/detector.cxx +++ b/demo/csrc/cpp/detector.cxx @@ -1,69 +1,47 @@ #include "mmdeploy/detector.hpp" -#include -#include -#include +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "utils/argparse.h" +#include "utils/visualize.h" -int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n object_detection device_name model_path image_path\n"); - return 1; - } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; - } - - mmdeploy::Model model(model_path); - mmdeploy::Detector detector(model, mmdeploy::Device{device_name, 0}); - - auto dets = detector.Apply(img); - - fprintf(stdout, "bbox_count=%d\n", (int)dets.size()); - - for (int i = 0; i < dets.size(); ++i) { - const auto& box = dets[i].bbox; - const auto& mask = dets[i].mask; +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "detector_output.jpg", "Output image path"); - fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n", - i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score); +DEFINE_double(det_thr, .5, "Detection score threshold"); - // skip detections with invalid bbox size (bbox height or width < 1) - if ((box.right - box.left) < 1 || (box.bottom - box.top) < 1) { - continue; - } +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } - // skip detections less than specified score threshold - if (dets[i].score < 0.3) { - continue; - } + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } - // generate mask overlay if model exports masks - if (mask != nullptr) { - fprintf(stdout, "mask %d, height=%d, width=%d\n", i, mask->height, mask->width); + // construct a detector instance + mmdeploy::Detector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); - cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]); - auto x0 = std::max(std::floor(box.left) - 1, 0.f); - auto y0 = std::max(std::floor(box.top) - 1, 0.f); - cv::Rect roi((int)x0, (int)y0, mask->width, mask->height); + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_detection_t`, will be released automatically on destruction + mmdeploy::Detector::Result dets = detector.Apply(img); - // split the RGB channels, overlay mask to a specific color channel - cv::Mat ch[3]; - split(img, ch); - int col = 0; // int col = i % 3; - cv::bitwise_or(imgMask, ch[col](roi), ch[col](roi)); - merge(ch, 3, img); + // visualize + utils::Visualize v; + auto sess = v.get_session(img); + int count = 0; + for (const mmdeploy_detection_t& det : dets) { + if (det.score > FLAGS_det_thr) { // filter bboxes + sess.add_det(det.bbox, det.label_id, det.score, det.mask, count++); } - - cv::rectangle(img, cv::Point{(int)box.left, (int)box.top}, - cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0}); } - cv::imwrite("output_detection.png", img); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } return 0; } diff --git a/demo/csrc/cpp/pose_tracker.cpp b/demo/csrc/cpp/pose_tracker.cpp deleted file mode 100644 index 077b3d39ea..0000000000 --- a/demo/csrc/cpp/pose_tracker.cpp +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "mmdeploy/pose_tracker.hpp" - -#include - -#include "opencv2/highgui/highgui.hpp" -#include "opencv2/imgcodecs/imgcodecs.hpp" -#include "opencv2/imgproc/imgproc.hpp" -#include "opencv2/videoio/videoio.hpp" - -struct Args { - std::string device; - std::string det_model; - std::string pose_model; - std::string video; - std::string output_dir; -}; - -Args ParseArgs(int argc, char* argv[]); - -using std::vector; -using namespace mmdeploy; - -bool Visualize(cv::Mat frame, const PoseTracker::Result& result, int size, - const std::string& output_dir, int frame_id, bool with_bbox); - -int main(int argc, char* argv[]) { - auto args = ParseArgs(argc, argv); - if (args.device.empty()) { - return 0; - } - - // create pose tracker pipeline - PoseTracker tracker(Model(args.det_model), Model(args.pose_model), Context{Device{args.device}}); - - // set parameters - PoseTracker::Params params; - params->det_min_bbox_size = 100; - params->det_interval = 1; - params->pose_max_num_bboxes = 6; - - // optionally use OKS for keypoints similarity comparison - std::array sigmas{0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, - 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089}; - params->keypoint_sigmas = sigmas.data(); - params->keypoint_sigmas_size = sigmas.size(); - - // create a tracker state for each video - PoseTracker::State state = tracker.CreateState(params); - - cv::VideoCapture video; - if (args.video.size() == 1 && std::isdigit(args.video[0])) { - video.open(std::stoi(args.video)); // open by camera index - } else { - video.open(args.video); // open video file - } - if (!video.isOpened()) { - std::cerr << "failed to open video: " << args.video << "\n"; - } - - cv::Mat frame; - int frame_id = 0; - while (true) { - video >> frame; - if (!frame.data) { - break; - } - // apply the pipeline with the tracker state and video frame - auto result = tracker.Apply(state, frame); - // visualize the results - if (!Visualize(frame, result, 1280, args.output_dir, frame_id++, false)) { - break; - } - } - - return 0; -} - -struct Skeleton { - vector> skeleton; - vector palette; - vector link_color; - vector point_color; -}; - -const Skeleton& gCocoSkeleton() { - static const Skeleton inst{ - { - {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, - {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, - {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, - }, - { - {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, - {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, - {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, - {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}, - }, - {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}, - {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}, - }; - return inst; -} - -bool Visualize(cv::Mat frame, const PoseTracker::Result& result, int size, - const std::string& output_dir, int frame_id, bool with_bbox) { - auto& [skeleton, palette, link_color, point_color] = gCocoSkeleton(); - auto scale = (float)size / (float)std::max(frame.cols, frame.rows); - if (scale != 1) { - cv::resize(frame, frame, {}, scale, scale); - } else { - frame = frame.clone(); - } - auto draw_bbox = [&](std::array bbox, const cv::Scalar& color) { - std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; }); - cv::rectangle(frame, cv::Point2f(bbox[0], bbox[1]), cv::Point2f(bbox[2], bbox[3]), color); - }; - for (const auto& r : result) { - vector kpts(&r.keypoints[0].x, &r.keypoints[0].x + r.keypoint_count * 2); - vector scores(r.scores, r.scores + r.keypoint_count); - std::for_each(kpts.begin(), kpts.end(), [&](auto& x) { x *= scale; }); - constexpr auto score_thr = .5f; - vector used(kpts.size()); - for (size_t i = 0; i < skeleton.size(); ++i) { - auto [u, v] = skeleton[i]; - if (scores[u] > score_thr && scores[v] > score_thr) { - used[u] = used[v] = 1; - cv::Point2f p_u(kpts[u * 2], kpts[u * 2 + 1]); - cv::Point2f p_v(kpts[v * 2], kpts[v * 2 + 1]); - cv::line(frame, p_u, p_v, palette[link_color[i]], 1, cv::LINE_AA); - } - } - for (size_t i = 0; i < kpts.size(); i += 2) { - if (used[i / 2]) { - cv::Point2f p(kpts[i], kpts[i + 1]); - cv::circle(frame, p, 1, palette[point_color[i / 2]], 2, cv::LINE_AA); - } - } - if (with_bbox) { - draw_bbox((std::array&)r.bbox, cv::Scalar(0, 255, 0)); - } - } - if (output_dir.empty()) { - cv::imshow("pose_tracker", frame); - return cv::waitKey(1) != 'q'; - } - auto join = [](const std::string& a, const std::string& b) { -#if _MSC_VER - return a + "\\" + b; -#else - return a + "/" + b; -#endif - }; - cv::imwrite(join(output_dir, std::to_string(frame_id) + ".jpg"), frame, - {cv::IMWRITE_JPEG_QUALITY, 90}); - return true; -} - -Args ParseArgs(int argc, char* argv[]) { - if (argc < 5 || argc > 6) { - std::cout << R"(Usage: pose_tracker device_name det_model pose_model video [output] - device_name device name for execution, e.g. "cpu", "cuda" - det_model object detection model path - pose_model pose estimation model path - video video path or camera index - output output directory, will cv::imshow if omitted -)"; - return {}; - } - Args args; - args.device = argv[1]; - args.det_model = argv[2]; - args.pose_model = argv[3]; - args.video = argv[4]; - if (argc == 6) { - args.output_dir = argv[5]; - } - return args; -} diff --git a/demo/csrc/cpp/pose_tracker.cxx b/demo/csrc/cpp/pose_tracker.cxx new file mode 100644 index 0000000000..c50bdfbe14 --- /dev/null +++ b/demo/csrc/cpp/pose_tracker.cxx @@ -0,0 +1,69 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "mmdeploy/pose_tracker.hpp" + +#include "utils/argparse.h" +#include "utils/mediaio.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(det_model, "Object detection model path"); +DEFINE_ARG_string(pose_model, "Pose estimation model path"); +DEFINE_ARG_string(input, "Input video path or camera index"); + +DEFINE_string(device, "cpu", "Device name, e.g. \"cpu\", \"cuda\""); +DEFINE_string(output, "", "Output video path or format string"); + +DEFINE_int32(output_size, 0, "Long-edge of output frames"); +DEFINE_int32(flip, 0, "Set to 1 for flipping the input horizontally"); +DEFINE_int32(show, 1, "Delay passed to `cv::waitKey` when using `cv::imshow`; -1: disable"); + +DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")"); +DEFINE_string(background, "default", + R"(Output background, "default": original image, "black": black background)"); + +#include "pose_tracker_params.h" + +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } + + // create pose tracker pipeline + mmdeploy::PoseTracker tracker(mmdeploy::Model(ARGS_det_model), mmdeploy::Model(ARGS_pose_model), + mmdeploy::Device{FLAGS_device}); + + mmdeploy::PoseTracker::Params params; + // initialize tracker params with program arguments + InitTrackerParams(params); + + // create a tracker state for each video + mmdeploy::PoseTracker::State state = tracker.CreateState(params); + + utils::mediaio::Input input(ARGS_input, FLAGS_flip); + utils::mediaio::Output output(FLAGS_output, FLAGS_show); + + utils::Visualize v(FLAGS_output_size); + v.set_background(FLAGS_background); + v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton)); + + for (const cv::Mat& frame : input) { + // apply the pipeline with the tracker state and video frame; the result is an array-like class + // holding references to `mmdeploy_pose_tracker_target_t`, will be released automatically on + // destruction + mmdeploy::PoseTracker::Result result = tracker.Apply(state, frame); + + // visualize results + auto sess = v.get_session(frame); + for (const mmdeploy_pose_tracker_target_t& target : result) { + sess.add_pose(target.keypoints, target.scores, target.keypoint_count, FLAGS_pose_kpt_thr); + } + + // write to output stream + if (!output.write(sess.get())) { + // user requested exit by pressing ESC + break; + } + } + + return 0; +} diff --git a/demo/csrc/cpp/pose_tracker_params.h b/demo/csrc/cpp/pose_tracker_params.h new file mode 100644 index 0000000000..2cda301869 --- /dev/null +++ b/demo/csrc/cpp/pose_tracker_params.h @@ -0,0 +1,39 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +DEFINE_int32(det_interval, 1, "Detection interval"); +DEFINE_int32(det_label, 0, "Detection label use for pose estimation"); +DEFINE_double(det_thr, 0.5, "Detection score threshold"); +DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size"); +DEFINE_double(det_nms_thr, .7, + "NMS IOU threshold for merging detected bboxes and bboxes from tracked targets"); + +DEFINE_int32(pose_max_num_bboxes, -1, "Max number of bboxes used for pose estimation per frame"); +DEFINE_double(pose_kpt_thr, .5, "Threshold for visible key-points"); +DEFINE_int32(pose_min_keypoints, -1, + "Min number of key-points for valid poses, -1 indicates ceil(n_kpts/2)"); +DEFINE_double(pose_bbox_scale, 1.25, "Scale for expanding key-points to bbox"); +DEFINE_double( + pose_min_bbox_size, -1, + "Min pose bbox size, tracks with bbox size smaller than the threshold will be dropped"); +DEFINE_double(pose_nms_thr, 0.5, + "NMS OKS/IOU threshold for suppressing overlapped poses, useful when multiple pose " + "estimations collapse to the same target"); + +DEFINE_double(track_iou_thr, 0.4, "IOU threshold for associating missing tracks"); +DEFINE_int32(track_max_missing, 10, + "Max number of missing frames before a missing tracks is removed"); + +void InitTrackerParams(mmdeploy::PoseTracker::Params& params) { + params->det_interval = FLAGS_det_interval; + params->det_label = FLAGS_det_label; + params->det_thr = FLAGS_det_thr; + params->det_min_bbox_size = FLAGS_det_min_bbox_size; + params->pose_max_num_bboxes = FLAGS_pose_max_num_bboxes; + params->pose_kpt_thr = FLAGS_pose_kpt_thr; + params->pose_min_keypoints = FLAGS_pose_min_keypoints; + params->pose_bbox_scale = FLAGS_pose_bbox_scale; + params->pose_min_bbox_size = FLAGS_pose_min_bbox_size; + params->pose_nms_thr = FLAGS_pose_nms_thr; + params->track_iou_thr = FLAGS_track_iou_thr; + params->track_max_missing = FLAGS_track_max_missing; +} diff --git a/demo/csrc/cpp/restorer.cxx b/demo/csrc/cpp/restorer.cxx index 7c3eefd82c..378294fed4 100644 --- a/demo/csrc/cpp/restorer.cxx +++ b/demo/csrc/cpp/restorer.cxx @@ -2,33 +2,39 @@ #include "mmdeploy/restorer.hpp" -#include -#include -#include +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "opencv2/imgproc/imgproc.hpp" +#include "utils/argparse.h" + +DEFINE_ARG_string(model, "Super-resolution model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "restorer_output.jpg", "Output image path"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n image_restorer device_name model_path image_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; } - using namespace mmdeploy; + // construct a restorer instance + mmdeploy::Restorer restorer{mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}}; - Restorer restorer{Model{model_path}, Device{device_name}}; + // apply restorer to the image + mmdeploy::Restorer::Result result = restorer.Apply(img); - auto result = restorer.Apply(img); + // convert to BGR + cv::Mat upsampled(result->height, result->width, CV_8UC3, result->data); + cv::cvtColor(upsampled, upsampled, cv::COLOR_RGB2BGR); - cv::Mat sr_img(result->height, result->width, CV_8UC3, result->data); - cv::cvtColor(sr_img, sr_img, cv::COLOR_RGB2BGR); - cv::imwrite("output_restorer.bmp", sr_img); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, upsampled); + } return 0; } diff --git a/demo/csrc/cpp/rotated_detector.cxx b/demo/csrc/cpp/rotated_detector.cxx index d590273d32..3ddb5d4c1c 100644 --- a/demo/csrc/cpp/rotated_detector.cxx +++ b/demo/csrc/cpp/rotated_detector.cxx @@ -1,51 +1,46 @@ #include "mmdeploy/rotated_detector.hpp" -#include -#include -#include +#include "utils/argparse.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "rotated_detector_output.jpg", "Output image path"); + +DEFINE_double(det_thr, 0.1, "Detection score threshold"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n oriented_object_detection device_name model_path image_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; } - mmdeploy::Model model(model_path); - mmdeploy::RotatedDetector detector(model, mmdeploy::Device{device_name, 0}); + // construct a detector instance + mmdeploy::RotatedDetector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); - auto dets = detector.Apply(img); + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_rotated_detection_t`, will be released automatically on destruction + mmdeploy::RotatedDetector::Result dets = detector.Apply(img); - for (const auto& det : dets) { - if (det.score < 0.1) { - continue; + // visualize results + utils::Visualize v; + auto sess = v.get_session(img); + for (const mmdeploy_rotated_detection_t& det : dets) { + if (det.score > FLAGS_det_thr) { + sess.add_rotated_det(det.rbbox, det.label_id, det.score); } - auto& rbbox = det.rbbox; - float xc = rbbox[0]; - float yc = rbbox[1]; - float w = rbbox[2]; - float h = rbbox[3]; - float ag = rbbox[4]; - float wx = w / 2 * std::cos(ag); - float wy = w / 2 * std::sin(ag); - float hx = -h / 2 * std::sin(ag); - float hy = h / 2 * std::cos(ag); - cv::Point p1 = {int(xc - wx - hx), int(yc - wy - hy)}; - cv::Point p2 = {int(xc + wx - hx), int(yc + wy - hy)}; - cv::Point p3 = {int(xc + wx + hx), int(yc + wy + hy)}; - cv::Point p4 = {int(xc - wx + hx), int(yc - wy + hy)}; - cv::drawContours(img, std::vector>{{p1, p2, p3, p4}}, -1, {0, 255, 0}, - 2); } - cv::imwrite("output_rotated_detection.png", img); + + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } return 0; } diff --git a/demo/csrc/cpp/segmentor.cxx b/demo/csrc/cpp/segmentor.cxx index 0c1dde49d6..55ba41db67 100644 --- a/demo/csrc/cpp/segmentor.cxx +++ b/demo/csrc/cpp/segmentor.cxx @@ -2,57 +2,46 @@ #include "mmdeploy/segmentor.hpp" -#include -#include -#include -#include #include #include -using namespace std; +#include "utils/argparse.h" +#include "utils/mediaio.h" +#include "utils/visualize.h" -vector gen_palette(int num_classes) { - std::mt19937 gen; - std::uniform_int_distribution uniform_dist(0, 255); - - vector palette; - palette.reserve(num_classes); - for (auto i = 0; i < num_classes; ++i) { - palette.emplace_back(uniform_dist(gen), uniform_dist(gen), uniform_dist(gen)); - } - return palette; -} +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "segmentor_output.jpg", "Output image path"); +DEFINE_string(palette, "cityscapes", + R"(Path to palette data or name of predefined palettes: "cityscapes")"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n image_segmentation device_name model_path image_path\n"); - return 1; - } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - using namespace mmdeploy; + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } - Segmentor segmentor{Model{model_path}, Device{device_name}}; + mmdeploy::Segmentor segmentor{mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}}; - auto result = segmentor.Apply(img); + // apply the detector, the result is an array-like class holding a reference to + // `mmdeploy_segmentation_t`, will be released automatically on destruction + mmdeploy::Segmentor::Result seg = segmentor.Apply(img); - auto palette = gen_palette(result->classes + 1); + // visualize + utils::Visualize v; + v.set_palette(utils::Palette::get(FLAGS_palette)); + auto sess = v.get_session(img); + sess.add_mask(seg->height, seg->width, seg->classes, seg->mask, seg->score); - cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3); - int pos = 0; - for (auto iter = color_mask.begin(); iter != color_mask.end(); ++iter) { - *iter = palette[result->mask[pos++]]; + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); } - img = img * 0.5 + color_mask * 0.5; - cv::imwrite("output_segmentation.png", img); - return 0; } diff --git a/demo/csrc/cpp/text_ocr.cxx b/demo/csrc/cpp/text_ocr.cxx index 853d681071..6c8fdb055b 100644 --- a/demo/csrc/cpp/text_ocr.cxx +++ b/demo/csrc/cpp/text_ocr.cxx @@ -1,46 +1,57 @@ -#include -#include #include #include "mmdeploy/text_detector.hpp" #include "mmdeploy/text_recognizer.hpp" +#include "utils/argparse.h" +#include "utils/mediaio.h" +#include "utils/visualize.h" -int main(int argc, char* argv[]) { - if (argc != 5) { - fprintf(stderr, "usage:\n ocr device_name det_model_path reg_model_path image_path\n"); - return 1; - } - const auto device_name = argv[1]; - auto det_model_path = argv[2]; - auto reg_model_path = argv[3]; - auto image_path = argv[4]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; - } +DEFINE_ARG_string(det_model, "Text detection model path"); +DEFINE_ARG_string(reg_model, "Text recognition model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "text_ocr_output.jpg", "Output image path"); - using namespace mmdeploy; +using mmdeploy::TextDetector; +using mmdeploy::TextRecognizer; - TextDetector detector{Model(det_model_path), Device(device_name)}; - TextRecognizer recognizer{Model(reg_model_path), Device(device_name)}; +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } - auto bboxes = detector.Apply(img); - auto texts = recognizer.Apply(img, {bboxes.begin(), bboxes.size()}); + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } - for (int i = 0; i < bboxes.size(); ++i) { - fprintf(stdout, "box[%d]: %s\n", i, texts[i].text); - std::vector poly_points; - for (const auto& pt : bboxes[i].bbox) { - fprintf(stdout, "x: %.2f, y: %.2f, ", pt.x, pt.y); - poly_points.emplace_back((int)pt.x, (int)pt.y); - } - fprintf(stdout, "\n"); - cv::polylines(img, poly_points, true, cv::Scalar{0, 255, 0}); + mmdeploy::Device device(FLAGS_device); + TextDetector detector{mmdeploy::Model(ARGS_det_model), device}; + TextRecognizer recognizer{mmdeploy::Model(ARGS_reg_model), device}; + + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_text_detection_t`, will be released automatically on destruction + TextDetector::Result bboxes = detector.Apply(img); + + // apply recognizer, if no bboxes are provided, full image will be used; the result is an + // array-like class holding references to `mmdeploy_text_recognition_t`, will be released + // automatically on destruction + TextRecognizer::Result texts = recognizer.Apply(img, {bboxes.begin(), bboxes.size()}); + + // visualize results + utils::Visualize v; + auto sess = v.get_session(img); + for (size_t i = 0; i < bboxes.size(); ++i) { + mmdeploy_text_detection_t& bbox = bboxes[i]; + mmdeploy_text_recognition_t& text = texts[i]; + sess.add_text_det(bbox.bbox, bbox.score, text.text, text.length, i); } - cv::imwrite("output_ocr.png", img); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } return 0; } diff --git a/demo/csrc/cpp/utils/argparse.h b/demo/csrc/cpp/utils/argparse.h new file mode 100644 index 0000000000..5c94c8afad --- /dev/null +++ b/demo/csrc/cpp/utils/argparse.h @@ -0,0 +1,272 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_ARGPARSE_H +#define MMDEPLOY_ARGPARSE_H + +#include +#include +#include +#include +#include +#include +#include + +#define DEFINE_int32(name, init, msg) _MMDEPLOY_DEFINE_FLAG(int32_t, name, init, msg) +#define DEFINE_double(name, init, msg) _MMDEPLOY_DEFINE_FLAG(double, name, init, msg) +#define DEFINE_string(name, init, msg) _MMDEPLOY_DEFINE_FLAG(std::string, name, init, msg) + +#define DEFINE_ARG_int32(name, msg) _MMDEPLOY_DEFINE_ARG(int32_t, name, msg) +#define DEFINE_ARG_double(name, msg) _MMDEPLOY_DEFINE_ARG(double, name, msg) +#define DEFINE_ARG_string(name, msg) _MMDEPLOY_DEFINE_ARG(std::string, name, msg) + +namespace utils { + +class ArgParse { + public: + template + static T Register(const std::string& type, const std::string& name, T init, + const std::string& msg, void* ptr) { + instance()._Register(type, name, msg, true, ptr); + return init; + } + + template + static T Register(const std::string& type, const std::string& name, const std::string& msg, + void* ptr) { + instance()._Register(type, name, msg, false, ptr); + return {}; + } + + static bool ParseArguments(int argc, char* argv[]) { + if (!instance()._Parse(argc, argv)) { + ShowUsageWithFlags(argv[0]); + return false; + } + return true; + } + + static void ShowUsageWithFlags(const char* argv0) { instance()._ShowUsageWithFlags(argv0); } + + private: + static ArgParse& instance() { + static ArgParse inst; + return inst; + } + + struct Info { + std::string name; + std::string type; + std::string msg; + bool is_flag; + void* ptr; + }; + + void _Register(std::string type, const std::string& name, const std::string& msg, bool is_flag, + void* ptr) { + if (type == "std::string") { + type = "string"; + } else if (type == "int32_t") { + type = "int32"; + } + infos_.push_back({name, type, msg, is_flag, ptr}); + } + + bool _Parse(int argc, char* argv[]) { + int arg_idx{-1}; + std::vector args(infos_.size()); + std::vector used(infos_.size()); + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + return false; + } + if (argv[i][0] == '-' && argv[i][1] == '-') { + // parse flag key-value pair (--x=y or --x y) + int eq{-1}; + for (int k = 2; argv[i][k]; ++k) { + if (argv[i][k] == '=') { + eq = k; + break; + } + } + std::string key; + std::string val; + if (eq >= 0) { + key = std::string(argv[i] + 2, argv[i] + eq); + val = std::string(argv[i] + eq + 1); + } else { + key = std::string(argv[i] + 2); + if (i < argc - 1) { + val = argv[++i]; + } + } + bool found{}; + for (int j = 0; j < infos_.size(); ++j) { + auto& flag = infos_[j]; + if (key == flag.name) { + args[j] = val; + found = used[j] = 1; + break; + } + } + if (!found) { + std::cout << "error: unknown option: " << key << std::endl; + return false; + } + } else { + for (arg_idx++; arg_idx < infos_.size(); ++arg_idx) { + if (!infos_[arg_idx].is_flag) { + args[arg_idx] = argv[i]; + used[arg_idx] = 1; + break; + } + } + if (arg_idx == infos_.size()) { + std::cout << "error: unknown argument: " << argv[i] << std::endl; + return false; + } + } + } + std::vector missing; + for (arg_idx++; arg_idx < infos_.size(); ++arg_idx) { + if (!infos_[arg_idx].is_flag) { + missing.push_back(infos_[arg_idx].name); + } + } + if (!missing.empty()) { + std::cout << "error: the following arguments are required:"; + for (int i = 0; i < missing.size(); ++i) { + std::cout << " " << missing[i]; + if (i != missing.size() - 1) { + std::cout << ","; + } + } + std::cout << "\n"; + return false; + } + + for (int i = 0; i < infos_.size(); ++i) { + if (used[i]) { + try { + parse_str(infos_[i], args[i]); + } catch (...) { + std::cout << "error: failed to parse " << infos_[i].name << ": " << args[i] << std::endl; + return false; + } + } + } + + return true; + } + + static void parse_str(Info& info, const std::string& str) { + if (info.type == "int32") { + *static_cast(info.ptr) = std::stoi(str); + } else if (info.type == "double") { + *static_cast(info.ptr) = std::stod(str); + } else if (info.type == "string") { + *static_cast(info.ptr) = str; + } else { + // pass + } + } + + static std::string get_default_str(const Info& info) { + if (info.type == "int32") { + return std::to_string(*static_cast(info.ptr)); + } else if (info.type == "double") { + std::ostringstream os; + os << std::setprecision(3) << *static_cast(info.ptr); + return os.str(); + } else if (info.type == "string") { + return "\"" + *(static_cast(info.ptr)) + "\""; + } else { + return ""; + } + } + + void _ShowUsageWithFlags(const char* argv0) const { + ShowUsage(argv0); + static constexpr const auto kLineLength = 80; + std::cout << std::endl; + int max_name_length = 0; + for (const auto& info : infos_) { + max_name_length = std::max(max_name_length, (int)info.name.length()); + } + max_name_length += 4; + auto name_col_size = max_name_length + 1; + auto msg_col_size = kLineLength - name_col_size; + std::cout << "required arguments:\n"; + ShowFlags(name_col_size, msg_col_size, false); + std::cout << std::endl; + std::cout << "optional arguments:\n"; + ShowFlags(name_col_size, msg_col_size, true); + } + + void ShowFlags(int name_col_size, int msg_col_size, bool is_flag) const { + for (const auto& info : infos_) { + if (info.is_flag != is_flag) { + continue; + } + std::string name = " "; + if (info.is_flag) { + name.append("--"); + } + name.append(info.name); + while (name.length() < name_col_size) { + name.append(" "); + } + std::cout << name; + std::string msg = info.msg; + while (msg.length() > msg_col_size) { // insert line-breaks when msg is too long + auto pos = msg.rend() - std::find(std::make_reverse_iterator(msg.begin() + msg_col_size), + msg.rend(), ' '); + std::cout << msg.substr(0, pos - 1) << std::endl; + std::cout << std::string(name_col_size, ' '); + msg = msg.substr(pos); + } + std::cout << msg; + std::string type; + type.append("[").append(info.type); + if (info.is_flag) { + type.append(" = ").append(get_default_str(info)); + } + type.append("]"); + if (msg.length() + type.length() + 1 > msg_col_size) { + std::cout << std::endl << std::string(name_col_size, ' ') << type; + } else { + std::cout << " " << type; + } + std::cout << std::endl; + } + } + + void ShowUsage(const char* argv0) const { + for (auto p = argv0; *p; ++p) { + if (*p == '/' || *p == '\'') { + argv0 = p + 1; + } + } + std::cout << "Usage: " << argv0 << " [options]"; + for (const auto& info : infos_) { + if (!info.is_flag) { + std::cout << " " << info.name; + } + } + std::cout << std::endl; + } + + private: + std::vector infos_; +}; + +inline bool ParseArguments(int argc, char* argv[]) { return ArgParse::ParseArguments(argc, argv); } + +} // namespace utils + +#define _MMDEPLOY_DEFINE_FLAG(type, name, init, msg) \ + type FLAGS_##name = ::utils::ArgParse::Register(#type, #name, type(init), msg, &FLAGS_##name) + +#define _MMDEPLOY_DEFINE_ARG(type, name, msg) \ + type ARGS_##name = ::utils::ArgParse::Register(#type, #name, msg, &ARGS_##name) + +#endif // MMDEPLOY_ARGPARSE_H diff --git a/demo/csrc/cpp/utils/mediaio.h b/demo/csrc/cpp/utils/mediaio.h new file mode 100644 index 0000000000..65018602c2 --- /dev/null +++ b/demo/csrc/cpp/utils/mediaio.h @@ -0,0 +1,393 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_MEDIAIO_H +#define MMDEPLOY_MEDIAIO_H + +#include +#include + +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "opencv2/videoio/videoio.hpp" + +namespace utils { +namespace mediaio { + +enum class MediaType { kUnknown, kImage, kVideo, kImageList, kWebcam, kFmtStr, kDisable }; + +namespace detail { + +static std::string get_extension(const std::string& path) { + std::string ext; + for (auto i = (int)path.size() - 1; i >= 0; --i) { + if (path[i] == '.') { + ext.push_back(path[i]); + for (++i; i < path.size(); ++i) { + ext.push_back((char)std::tolower((unsigned char)path[i])); + } + return ext; + } + } + return {}; +} + +int ext2fourcc(const std::string& ext) { + auto get_fourcc = [](const char* s) { return cv::VideoWriter::fourcc(s[0], s[1], s[2], s[3]); }; + static std::map ext2fourcc{ + {".mp4", get_fourcc("mp4v")}, + {".avi", get_fourcc("DIVX")}, + {".mkv", get_fourcc("X264")}, + {".wmv", get_fourcc("WMV3")}, + }; + auto it = ext2fourcc.find(ext); + if (it != ext2fourcc.end()) { + return it->second; + } + return get_fourcc("DIVX"); +} + +static bool is_video(const std::string& ext) { + static const std::set es{".mp4", ".avi", ".mkv", ".webm", ".mov", ".mpg", ".wmv"}; + return es.count(ext); +} + +static bool is_list(const std::string& ext) { + static const std::set es{".txt"}; + return es.count(ext); +} + +static bool is_image(const std::string& ext) { + static const std::set es{".jpg", ".jpeg", ".png", ".tif", ".tiff", + ".bmp", ".ppm", ".pgm", ".webp"}; + return es.count(ext); +} + +static bool is_fmtstr(const std::string& str) { + for (const auto& c : str) { + if (c == '%') { + return true; + } + } + return false; +} + +} // namespace detail + +class Input; + +class InputIterator { + public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = cv::Mat&; + using value_type = reference; + using pointer = void; + + public: + InputIterator() = default; + explicit InputIterator(Input& input) : input_(&input) { next(); } + InputIterator& operator++() { + next(); + return *this; + } + reference operator*() { return frame_; } + friend bool operator==(const InputIterator& a, const InputIterator& b) { + return &a == &b || a.is_end() == b.is_end(); + } + friend bool operator!=(const InputIterator& a, const InputIterator& b) { return !(a == b); } + + private: + void next(); + bool is_end() const noexcept { return frame_.data != nullptr; } + + private: + cv::Mat frame_; + Input* input_{}; +}; + +class BatchInputIterator { + public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = std::vector&; + using value_type = reference; + using pointer = void; + + public: + BatchInputIterator() = default; + BatchInputIterator(InputIterator iter, InputIterator end, size_t batch_size) + : iter_(std::move(iter)), end_(std::move(end)), batch_size_(batch_size) { + next(); + } + + BatchInputIterator& operator++() { + next(); + return *this; + } + + reference operator*() { return data_; } + + friend bool operator==(const BatchInputIterator& a, const BatchInputIterator& b) { + return &a == &b || a.is_end() == b.is_end(); + } + + friend bool operator!=(const BatchInputIterator& a, const BatchInputIterator& b) { + return !(a == b); + } + + private: + void next() { + data_.clear(); + for (size_t i = 0; i < batch_size_ && iter_ != end_; ++i, ++iter_) { + data_.push_back(*iter_); + } + } + + bool is_end() const { return data_.empty(); } + + private: + InputIterator iter_; + InputIterator end_; + size_t batch_size_{1}; + std::vector data_; +}; + +class Input { + public: + explicit Input(const std::string& path, bool flip = false, MediaType type = MediaType::kUnknown) + : path_(path), flip_(flip), type_(type) { + if (type_ == MediaType::kUnknown) { + auto ext = detail::get_extension(path); + if (detail::is_image(ext)) { + type_ = MediaType::kImage; + } else if (detail::is_video(ext)) { + type_ = MediaType::kVideo; + } else if (path.size() == 1 && std::isdigit((unsigned char)path[0])) { + type_ = MediaType::kWebcam; + } else if (detail::is_list(ext) || try_image_list(path)) { + type_ = MediaType::kImageList; + } else if (try_image(path)) { + type_ = MediaType::kImage; + } else if (try_video(path)) { + type_ = MediaType::kVideo; + } else { + std::cout << "unknown file type: " << path << "\n"; + } + } + if (type_ != MediaType::kUnknown) { + if (type_ == MediaType::kVideo) { + cap_.open(path_); + if (!cap_.isOpened()) { + std::cerr << "failed to open video file: " << path_ << "\n"; + } + } else if (type_ == MediaType::kWebcam) { + cap_.open(std::stoi(path_)); + if (!cap_.isOpened()) { + std::cerr << "failed to open camera index: " << path_ << "\n"; + } + type_ = MediaType::kVideo; + } else if (type_ == MediaType::kImage) { + items_ = {path_}; + type_ = MediaType::kImageList; + } else if (type_ == MediaType::kImageList) { + if (items_.empty()) { + items_ = load_image_list(path); + } + } + } + } + InputIterator begin() { return InputIterator(*this); } + InputIterator end() { return {}; } // NOLINT + + cv::Mat read() { + cv::Mat img; + if (type_ == MediaType::kVideo) { + cap_ >> img; + } else if (type_ == MediaType::kImageList) { + while (!img.data && index_ < items_.size()) { + auto path = items_[index_++]; + img = cv::imread(path); + if (!img.data) { + std::cerr << "failed to load image: " << path << "\n"; + } + } + } + if (flip_ && !img.empty()) { + cv::flip(img, img, 1); + } + return img; + } + + class Batch { + public: + Batch(Input& input, size_t batch_size) : input_(&input), batch_size_(batch_size) {} + BatchInputIterator begin() { return {input_->begin(), input_->end(), batch_size_}; } + BatchInputIterator end() { return {}; } // NOLINT + + private: + Input* input_{}; + size_t batch_size_{1}; + }; + + Batch batch(size_t batch_size) { return {*this, batch_size}; } + + private: + static bool try_image(const std::string& path) { return cv::imread(path).data; } + + static bool try_video(const std::string& path) { return cv::VideoCapture(path).isOpened(); } + + static std::vector load_image_list(const std::string& path, size_t max_bytes = 0) { + std::ifstream ifs(path); + ifs.seekg(0, std::ifstream::end); + auto size = ifs.tellg(); + ifs.seekg(0, std::ifstream::beg); + if (max_bytes && size > max_bytes) { + return {}; + } + auto strip = [](std::string& s) { + while (!s.empty() && std::isspace((unsigned char)s.back())) { + s.pop_back(); + } + }; + std::vector ret; + std::string line; + while (std::getline(ifs, line)) { + strip(line); + if (!line.empty()) { + ret.push_back(std::move(line)); + } + } + return ret; + } + + bool try_image_list(const std::string& path) { + auto items = load_image_list(path, 1 << 20); + size_t count = 0; + for (const auto& item : items) { + if (detail::is_image(detail::get_extension(item)) && ++count > items.size() / 10) { + items_ = std::move(items); + return true; + } + } + return false; + } + + private: + std::string path_; + bool flip_{}; + MediaType type_{MediaType::kUnknown}; + std::vector items_; + cv::VideoCapture cap_; + size_t index_{}; +}; + +inline void InputIterator::next() { + assert(input_); + frame_ = input_->read(); +} + +class Output; + +class OutputIterator { + public: + using iterator_category = std::output_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = void; + using value_type = void; + using pointer = void; + + public: + explicit OutputIterator(Output& output) : output_(&output) {} + + OutputIterator& operator=(const cv::Mat& frame); + + OutputIterator& operator*() { return *this; } + OutputIterator& operator++() { return *this; } + OutputIterator& operator++(int) { return *this; } // NOLINT + + private: + Output* output_{}; +}; + +class Output { + public: + explicit Output(const std::string& path, int show, MediaType type = MediaType::kUnknown) + : path_(path), type_(type), show_(show) { + ext_ = detail::get_extension(path); + if (type_ == MediaType::kUnknown) { + if (path_.empty()) { + type_ = MediaType::kDisable; + } else if (detail::is_image(ext_)) { + if (detail::is_fmtstr(path)) { + type_ = MediaType::kFmtStr; + } else { + type_ = MediaType::kImage; + } + } else if (detail::is_video(ext_)) { + type_ = MediaType::kVideo; + } else { + std::cout << "unknown file type: " << path << "\n"; + } + } + } + + bool write(const cv::Mat& frame) { + bool exit = false; + switch (type_) { + case MediaType::kDisable: + break; + case MediaType::kImage: + cv::imwrite(path_, frame); + break; + case MediaType::kFmtStr: { + char buf[256]; + snprintf(buf, sizeof(buf), path_.c_str(), frame_id_); + cv::imwrite(buf, frame); + break; + } + case MediaType::kVideo: + write_video(frame); + break; + default: + std::cout << "unsupported output media type\n"; + assert(0); + } + if (show_ >= 0) { + cv::imshow("", frame); + exit = cv::waitKey(show_) == 27; // ESC + } + ++frame_id_; + return !exit; + } + + OutputIterator inserter() { return OutputIterator{*this}; } + + private: + void write_video(const cv::Mat& frame) { + if (!video_.isOpened()) { + open_video(frame.size()); + } + video_ << frame; + } + + void open_video(const cv::Size& size) { video_.open(path_, detail::ext2fourcc(ext_), 30, size); } + + private: + std::string path_; + std::string ext_; + MediaType type_{MediaType::kUnknown}; + int show_{1}; + size_t frame_id_{0}; + cv::VideoWriter video_; +}; + +OutputIterator& OutputIterator::operator=(const cv::Mat& frame) { + assert(output_); + output_->write(frame); + return *this; +} + +} // namespace mediaio +} // namespace utils + +#endif // MMDEPLOY_MEDIAIO_H diff --git a/demo/csrc/cpp/utils/palette.h b/demo/csrc/cpp/utils/palette.h new file mode 100644 index 0000000000..03d76a4360 --- /dev/null +++ b/demo/csrc/cpp/utils/palette.h @@ -0,0 +1,94 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_PALETTE_H +#define MMDEPLOY_PALETTE_H + +#include +#include +#include +#include +#include +#include + +namespace utils { + +struct Palette { + std::vector data; + static Palette get(const std::string& path); + static Palette get(int n); +}; + +inline Palette Palette::get(const std::string& path) { + if (path == "coco") { + Palette p{{ + {220, 20, 60}, {119, 11, 32}, {0, 0, 142}, {0, 0, 230}, {106, 0, 228}, + {0, 60, 100}, {0, 80, 100}, {0, 0, 70}, {0, 0, 192}, {250, 170, 30}, + {100, 170, 30}, {220, 220, 0}, {175, 116, 175}, {250, 0, 30}, {165, 42, 42}, + {255, 77, 255}, {0, 226, 252}, {182, 182, 255}, {0, 82, 0}, {120, 166, 157}, + {110, 76, 0}, {174, 57, 255}, {199, 100, 0}, {72, 0, 118}, {255, 179, 240}, + {0, 125, 92}, {209, 0, 151}, {188, 208, 182}, {0, 220, 176}, {255, 99, 164}, + {92, 0, 73}, {133, 129, 255}, {78, 180, 255}, {0, 228, 0}, {174, 255, 243}, + {45, 89, 255}, {134, 134, 103}, {145, 148, 174}, {255, 208, 186}, {197, 226, 255}, + {171, 134, 1}, {109, 63, 54}, {207, 138, 255}, {151, 0, 95}, {9, 80, 61}, + {84, 105, 51}, {74, 65, 105}, {166, 196, 102}, {208, 195, 210}, {255, 109, 65}, + {0, 143, 149}, {179, 0, 194}, {209, 99, 106}, {5, 121, 0}, {227, 255, 205}, + {147, 186, 208}, {153, 69, 1}, {3, 95, 161}, {163, 255, 0}, {119, 0, 170}, + {0, 182, 199}, {0, 165, 120}, {183, 130, 88}, {95, 32, 0}, {130, 114, 135}, + {110, 129, 133}, {166, 74, 118}, {219, 142, 185}, {79, 210, 114}, {178, 90, 62}, + {65, 70, 15}, {127, 167, 115}, {59, 105, 106}, {142, 108, 45}, {196, 172, 0}, + {95, 54, 80}, {128, 76, 255}, {201, 57, 1}, {246, 0, 122}, {191, 162, 208}, + }}; + for (auto& x : p.data) { + std::swap(x[0], x[2]); + } + return p; + } else if (path == "cityscapes") { + Palette p{{ + {128, 64, 128}, {244, 35, 232}, {70, 70, 70}, {102, 102, 156}, {190, 153, 153}, + {153, 153, 153}, {250, 170, 30}, {220, 220, 0}, {107, 142, 35}, {152, 251, 152}, + {70, 130, 180}, {220, 20, 60}, {255, 0, 0}, {0, 0, 142}, {0, 0, 70}, + {0, 60, 100}, {0, 80, 100}, {0, 0, 230}, {119, 11, 32}, + }}; + for (auto& x : p.data) { + std::swap(x[0], x[2]); + } + return p; + } + std::ifstream ifs(path); + if (!ifs.is_open()) { + std::cout << "error: failed to open palette data file: " << path << "\n"; + std::abort(); + } + Palette p; + int n = 0; + ifs >> n; + for (int i = 0; i < n; ++i) { + cv::Vec3b x{}; + ifs >> x[0] >> x[1] >> x[2]; + p.data.push_back(x); + } + return p; +} + +inline Palette Palette::get(int n) { + std::vector samples(n * 100); + std::vector indices(samples.size()); + std::iota(indices.begin(), indices.end(), 0); + std::mt19937 gen; // NOLINT + std::uniform_int_distribution uniform_dist(0, 255); + for (auto& x : samples) { + x = {(float)uniform_dist(gen), (float)uniform_dist(gen), (float)uniform_dist(gen)}; + } + std::vector centers; + cv::kmeans(samples, n, indices, cv::TermCriteria(cv::TermCriteria::Type::COUNT, 10, 0), 1, + cv::KMEANS_PP_CENTERS, centers); + Palette p; + for (const auto& c : centers) { + p.data.emplace_back((uchar)c.x, (uchar)c.y, (uchar)c.z); + } + return p; +} + +} // namespace utils + +#endif // MMDEPLOY_PALETTE_H diff --git a/demo/csrc/cpp/utils/skeleton.h b/demo/csrc/cpp/utils/skeleton.h new file mode 100644 index 0000000000..59b0df415a --- /dev/null +++ b/demo/csrc/cpp/utils/skeleton.h @@ -0,0 +1,89 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_SKELETON_H +#define MMDEPLOY_SKELETON_H + +#include +#include +#include +#include +#include + +namespace utils { + +struct Skeleton { + std::vector> links; + std::vector palette; + std::vector link_colors; + std::vector point_colors; + static Skeleton get(const std::string& path); +}; + +const Skeleton& gCocoSkeleton() { + static const Skeleton inst{ + { + {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, + {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, + {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, + }, + { + {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, + {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, + {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, + {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}, + }, + {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}, + {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}, + }; + return inst; +} + +// n_links +// u0, v0, u1, v1, ..., un-1, vn-1 +// n_palette +// b0, g0, r0, ..., bn-1, gn-1, rn-1 +// n_link_color +// i0, i1, ..., in-1 +// n_point_color +// j0, j1, ..., jn-1 +inline Skeleton Skeleton::get(const std::string& path) { + if (path == "coco") { + return gCocoSkeleton(); + } + std::ifstream ifs(path); + if (!ifs.is_open()) { + std::cout << "error: failed to open skeleton data file: " << path << "\n"; + std::abort(); + } + Skeleton skel; + int n = 0; + ifs >> n; + for (int i = 0; i < n; ++i) { + int u{}, v{}; + ifs >> u >> v; + skel.links.emplace_back(u, v); + } + ifs >> n; + for (int i = 0; i < n; ++i) { + int b{}, g{}, r{}; + ifs >> b >> g >> r; + skel.palette.emplace_back(b, g, r); + } + ifs >> n; + for (int i = 0; i < n; ++i) { + int x{}; + ifs >> x; + skel.link_colors.push_back(x); + } + ifs >> n; + for (int i = 0; i < n; ++i) { + int x{}; + ifs >> x; + skel.point_colors.push_back(x); + } + return skel; +} + +} // namespace utils + +#endif // MMDEPLOY_SKELETON_H diff --git a/demo/csrc/cpp/utils/visualize.h b/demo/csrc/cpp/utils/visualize.h new file mode 100644 index 0000000000..e9d8493f58 --- /dev/null +++ b/demo/csrc/cpp/utils/visualize.h @@ -0,0 +1,252 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_VISUALIZE_H +#define MMDEPLOY_VISUALIZE_H + +#include +#include +#include +#include + +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" +#include "palette.h" +#include "skeleton.h" + +namespace utils { + +class Visualize { + public: + class Session { + public: + explicit Session(Visualize& v, const cv::Mat& frame) : v_(v) { + if (v_.size_) { + scale_ = (float)v_.size_ / (float)std::max(frame.cols, frame.rows); + } + cv::Mat img; + if (v.background_ == "black") { + img = cv::Mat::zeros(frame.size(), CV_8UC3); + } else { + img = frame; + if (img.channels() == 1) { + cv::cvtColor(img, img, cv::COLOR_GRAY2BGR); + } + } + if (scale_ != 1) { + cv::resize(img, img, {}, scale_, scale_); + } else if (img.data == frame.data) { + img = img.clone(); + } + img_ = std::move(img); + } + + void add_label(int label_id, float score, int index) { + printf("label: %d, label_id: %d, score: %.4f\n", index, label_id, score); + auto size = .5f * static_cast(img_.rows + img_.cols); + offset_ += add_text(to_text(label_id, score), {1, (float)offset_}, size) + 2; + } + + int add_text(const std::string& text, const cv::Point2f& origin, float size) { + static constexpr const int font_face = cv::FONT_HERSHEY_SIMPLEX; + static constexpr const int thickness = 1; + static constexpr const auto max_font_scale = .5f; + static constexpr const auto min_font_scale = .25f; + float font_scale{}; + if (size < 20) { + font_scale = min_font_scale; + } else if (size > 200) { + font_scale = max_font_scale; + } else { + font_scale = min_font_scale + (size - 20) / (200 - 20) * (max_font_scale - min_font_scale); + } + int baseline{}; + auto text_size = cv::getTextSize(text, font_face, font_scale, thickness, &baseline); + cv::Rect rect(origin + cv::Point2f(0, text_size.height + 2 * thickness), + origin + cv::Point2f(text_size.width, 0)); + rect &= cv::Rect({}, img_.size()); + img_(rect) *= .35f; + cv::putText(img_, text, origin + cv::Point2f(0, text_size.height), font_face, font_scale, + cv::Scalar::all(255), thickness, cv::LINE_AA); + return text_size.height; + } + + static std::string to_text(int label_id, float score) { + std::stringstream ss; + ss << label_id << ": " << std::fixed << std::setprecision(1) << score * 100; + return ss.str(); + } + + template + void add_det(const mmdeploy_rect_t& rect, int label_id, float score, const Mask* mask, + int index) { + printf("bbox %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n", index, + rect.left, rect.top, rect.right, rect.bottom, label_id, score); + if (mask) { + fprintf(stdout, "mask %d, height=%d, width=%d\n", index, mask->height, mask->width); + auto x0 = (int)std::max(std::floor(rect.left) - 1, 0.f); + auto y0 = (int)std::max(std::floor(rect.top) - 1, 0.f); + add_instance_mask({x0, y0}, rand(), mask->data, mask->height, mask->width); + } + add_bbox(rect, label_id, score); + } + + void add_instance_mask(const cv::Point& origin, int color_id, const char* mask_data, int mask_h, + int mask_w, float alpha = .5f) { + auto color = v_.palette_.data[color_id % v_.palette_.data.size()]; + auto x_end = std::min(origin.x + mask_w, img_.cols); + auto y_end = std::min(origin.y + mask_h, img_.rows); + auto img_data = img_.ptr(); + for (int i = origin.y; i < y_end; ++i) { + for (int j = origin.x; j < x_end; ++j) { + if (mask_data[(i - origin.y) * mask_w + (j - origin.x)]) { + img_data[i * img_.cols + j] = img_data[i * img_.cols + j] * (1 - alpha) + color * alpha; + } + } + } + } + + void add_bbox(mmdeploy_rect_t rect, int label_id, float score) { + rect.left *= scale_; + rect.right *= scale_; + rect.top *= scale_; + rect.bottom *= scale_; + if (label_id >= 0 && score > 0) { + auto area = std::max(0.f, (rect.right - rect.left) * (rect.bottom - rect.top)); + add_text(to_text(label_id, score), {rect.left, rect.top}, std::sqrt(area)); + } + cv::rectangle(img_, cv::Point2f(rect.left, rect.top), cv::Point2f(rect.right, rect.bottom), + cv::Scalar(0, 255, 0)); + } + + void add_text_det(mmdeploy_point_t bbox[4], float score, const char* text, size_t text_size, + int index) { + printf("bbox[%d]: (%.2f, %.2f), (%.2f, %.2f), (%.2f, %.2f), (%.2f, %.2f)\n", index, // + bbox[0].x, bbox[0].y, // + bbox[1].x, bbox[1].y, // + bbox[2].x, bbox[2].y, // + bbox[3].x, bbox[3].y); + std::vector poly_points; + cv::Point2f center{}; + for (int i = 0; i < 4; ++i) { + poly_points.emplace_back(bbox[i].x * scale_, bbox[i].y * scale_); + center += cv::Point2f(poly_points.back()); + } + cv::polylines(img_, poly_points, true, cv::Scalar{0, 255, 0}, 1, cv::LINE_AA); + if (text) { + auto area = cv::contourArea(poly_points); + fprintf(stdout, "text[%d]: %s\n", index, text); + add_text(std::string(text, text + text_size), center / 4, std::sqrt(area)); + } + } + + void add_rotated_det(const float bbox[5], int label_id, float score) { + float xc = bbox[0] * scale_; + float yc = bbox[1] * scale_; + float w = bbox[2] * scale_; + float h = bbox[3] * scale_; + float ag = bbox[4]; + float wx = w / 2 * std::cos(ag); + float wy = w / 2 * std::sin(ag); + float hx = -h / 2 * std::sin(ag); + float hy = h / 2 * std::cos(ag); + cv::Point2f p1{xc - wx - hx, yc - wy - hy}; + cv::Point2f p2{xc + wx - hx, yc + wy - hy}; + cv::Point2f p3{xc + wx + hx, yc + wy + hy}; + cv::Point2f p4{xc - wx + hx, yc - wy + hy}; + cv::Point2f c = .25f * (p1 + p2 + p3 + p4); + cv::drawContours( + img_, + std::vector>{{p1 * scale_, p2 * scale_, p3 * scale_, p4 * scale_}}, + -1, {0, 255, 0}, 2, cv::LINE_AA); + add_text(to_text(label_id, score), c, std::sqrt(w * h)); + } + + void add_mask(int height, int width, int n_classes, const int* mask, const float* score) { + cv::Mat color_mask = cv::Mat::zeros(height, width, CV_8UC3); + auto n_pix = color_mask.total(); + + // compute top 1 idx if score (CHW) is available + cv::Mat_ top; + if (!mask && score) { + top = cv::Mat_::zeros(height, width); + for (auto c = 1; c < n_classes; ++c) { + top.forEach([&](int& x, const int* idx) { + auto offset = idx[0] * width + idx[1]; + if (score[c * n_pix + offset] > score[x * n_pix + offset]) { + x = c; + } + }); + } + mask = top.ptr(); + } + + if (mask) { + // palette look-up + color_mask.forEach([&](cv::Vec3b& x, const int* idx) { + auto& palette = v_.palette_.data; + x = palette[mask[idx[0] * width + idx[1]] % palette.size()]; + }); + + if (color_mask.size() != img_.size()) { + cv::resize(color_mask, color_mask, img_.size()); + } + + // blend mask and background image + cv::addWeighted(img_, .5, color_mask, .5, 0., img_); + } + } + + void add_pose(const mmdeploy_point_t* pts, const float* scores, int32_t pts_size, double thr) { + auto& skel = v_.skeleton_; + std::vector used(pts_size); + std::vector is_end_point(pts_size); + for (size_t i = 0; i < skel.links.size(); ++i) { + auto u = skel.links[i].first; + auto v = skel.links[i].second; + is_end_point[u] = is_end_point[v] = 1; + if (scores[u] > thr && scores[v] > thr) { + used[u] = used[v] = 1; + cv::Point2f p0(pts[u].x, pts[u].y); + cv::Point2f p1(pts[v].x, pts[v].y); + cv::line(img_, p0 * scale_, p1 * scale_, skel.palette[skel.link_colors[i]], 1, + cv::LINE_AA); + } + } + for (size_t i = 0; i < pts_size; ++i) { + if (!is_end_point[i] && scores[i] > thr || used[i]) { + cv::Point2f p(pts[i].x, pts[i].y); + cv::circle(img_, p * scale_, 1, skel.palette[skel.point_colors[i]], 2, cv::LINE_AA); + } + } + } + + cv::Mat get() { return img_; } + + private: + Visualize& v_; + float scale_{1}; + int offset_{1}; + cv::Mat img_; + }; + + explicit Visualize(int size = 0) : size_(size) { palette_ = Palette::get(32); } + + Session get_session(const cv::Mat& frame) { return Session(*this, frame); } + + void set_skeleton(const Skeleton& skeleton) { skeleton_ = skeleton; } + + void set_palette(const Palette& palette) { palette_ = palette; } + + void set_background(const std::string& background) { background_ = background; } + + private: + friend Session; + Skeleton skeleton_; + Palette palette_; + std::string background_; + int size_{}; +}; + +} // namespace utils + +#endif // MMDEPLOY_VISUALIZE_H diff --git a/demo/csrc/cpp/video_cls.cxx b/demo/csrc/cpp/video_cls.cxx index b69c2cb57c..3d87ee4f7b 100644 --- a/demo/csrc/cpp/video_cls.cxx +++ b/demo/csrc/cpp/video_cls.cxx @@ -3,8 +3,8 @@ #include #include "mmdeploy/video_recognizer.hpp" -#include "opencv2/imgcodecs/imgcodecs.hpp" #include "opencv2/videoio.hpp" +#include "utils/argparse.h" void SampleFrames(const char* video_path, std::map& buffer, std::vector& clips, int clip_len, int frame_interval = 1, @@ -57,28 +57,26 @@ void SampleFrames(const char* video_path, std::map& buffer, } } +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(video, "Input video path"); +DEFINE_ARG_int32(clip_len, "Clip length"); +DEFINE_ARG_int32(frame_interval, "Frame interval"); +DEFINE_ARG_int32(num_clips, "Number of clips"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); + int main(int argc, char* argv[]) { - if (argc != 7) { - fprintf(stderr, - "usage:\n video_cls device_name model_path video_path video_path clip_len " - "frame_interval num_clips\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto video_path = argv[3]; - - int clip_len = std::stoi(argv[4]); - int frame_interval = std::stoi(argv[5]); - int num_clips = std::stoi(argv[6]); std::map buffer; std::vector clips; - mmdeploy::VideoSampleInfo clip_info = {clip_len, num_clips}; - SampleFrames(video_path, buffer, clips, clip_len, frame_interval, num_clips); + mmdeploy::VideoSampleInfo clip_info = {ARGS_clip_len, ARGS_num_clips}; + SampleFrames(ARGS_video.c_str(), buffer, clips, ARGS_clip_len, ARGS_frame_interval, + ARGS_num_clips); - mmdeploy::Model model(model_path); - mmdeploy::VideoRecognizer recognizer(model, mmdeploy::Device{device_name, 0}); + mmdeploy::Model model(ARGS_model); + mmdeploy::VideoRecognizer recognizer(model, mmdeploy::Device{FLAGS_device}); auto res = recognizer.Apply(clips, clip_info); diff --git a/demo/python/image_segmentation.py b/demo/python/image_segmentation.py index 32391f4345..e70b088917 100644 --- a/demo/python/image_segmentation.py +++ b/demo/python/image_segmentation.py @@ -35,6 +35,8 @@ def main(): segmentor = Segmentor( model_path=args.model_path, device_name=args.device_name, device_id=0) seg = segmentor(img) + if seg.dtype == np.float32: + seg = np.argmax(seg, axis=0) palette = get_palette() color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) diff --git a/tests/test_csrc/preprocess/test_permute.cpp b/tests/test_csrc/preprocess/test_permute.cpp new file mode 100644 index 0000000000..a12a4aad6e --- /dev/null +++ b/tests/test_csrc/preprocess/test_permute.cpp @@ -0,0 +1,121 @@ + +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "catch.hpp" +#include "mmdeploy/core/mat.h" +#include "mmdeploy/core/tensor.h" +#include "mmdeploy/core/utils/device_utils.h" +#include "mmdeploy/operation/managed.h" +#include "mmdeploy/operation/vision.h" +#include "mmdeploy/preprocess/transform/transform.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace framework; +using namespace std; +using namespace mmdeploy::test; + +template +bool CheckEqual(const Tensor& res, const vector& expected) { + auto r = res.data(); + auto e = expected.data(); + for (int i = 0; i < expected.size(); i++) { + if (r[i] != e[i]) { + return false; + } + } + return true; +} + +template +void TestPermute(const Tensor& src, const vector& axes, const vector& expected) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + ::mmdeploy::operation::Context ctx(device, stream); + auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); + Tensor dst; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(!ret.has_error()); + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(dst, kHost, stream); + REQUIRE(CheckEqual(host_tensor.value(), expected)); + } +} + +void TestPermuteWrongArgs(const Tensor& src) { + int sz = src.shape().size(); + vector oaxes(sz); + std::iota(oaxes.begin(), oaxes.end(), 0); + + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + ::mmdeploy::operation::Context ctx(device, stream); + auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); + Tensor dst; + { + auto axes = oaxes; + axes[0]--; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(ret.has_error()); + } + { + auto axes = oaxes; + axes.back()++; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(ret.has_error()); + } + { + auto axes = oaxes; + axes[0] = axes[1]; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(ret.has_error()); + } + } +} + +TEST_CASE("operation Permute", "[permute]") { + const Device kHost{"cpu"}; + const int kSize = 2 * 3 * 2 * 4; + vector data(kSize); + std::iota(data.begin(), data.end(), 0); // [0, 48) + TensorDesc desc = {kHost, DataType::kINT8, {kSize}}; + Tensor tensor(desc); + memcpy(tensor.data(), data.data(), data.size() * sizeof(uint8_t)); + + SECTION("permute: wrong axes") { + Tensor src = tensor; + src.Reshape({6, 8}); + TestPermuteWrongArgs(src); + } + + SECTION("permute: dims 4") { + Tensor src = tensor; + src.Reshape({2, 3, 2, 4}); + vector axes = {1, 0, 3, 2}; + vector expected = {0, 4, 1, 5, 2, 6, 3, 7, 24, 28, 25, 29, 26, 30, 27, 31, + 8, 12, 9, 13, 10, 14, 11, 15, 32, 36, 33, 37, 34, 38, 35, 39, + 16, 20, 17, 21, 18, 22, 19, 23, 40, 44, 41, 45, 42, 46, 43, 47}; + Tensor dst(src.desc()); + memcpy(dst.data(), expected.data(), data.size() * sizeof(uint8_t)); + TestPermute(src, axes, expected); + } + + SECTION("permute: dims 5") { + Tensor src = tensor; + src.Reshape({2, 3, 1, 2, 4}); + vector axes = {2, 0, 1, 4, 3}; + vector expected = {0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15, + 16, 20, 17, 21, 18, 22, 19, 23, 24, 28, 25, 29, 26, 30, 27, 31, + 32, 36, 33, 37, 34, 38, 35, 39, 40, 44, 41, 45, 42, 46, 43, 47}; + Tensor dst(src.desc()); + memcpy(dst.data(), expected.data(), data.size() * sizeof(uint8_t)); + TestPermute(src, axes, expected); + } +}