diff --git a/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt b/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt
index 2267b029d4..968b6e956c 100644
--- a/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt
+++ b/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt
@@ -6,7 +6,9 @@ project(mmdeploy_mmpose)
 file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
 mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
 target_link_libraries(${PROJECT_NAME} PRIVATE
-        mmdeploy::transform mmdeploy_opencv_utils)
+        mmdeploy::transform
+        mmdeploy_operation
+        mmdeploy_opencv_utils)
 add_library(mmdeploy::mmpose ALIAS ${PROJECT_NAME})
 set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} pose_detector CACHE INTERNAL "")
diff --git a/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp b/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp
index d49885fbcc..75b93fa847 100644
--- a/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp
+++ b/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp
@@ -7,6 +7,8 @@
 #include "mmdeploy/core/tensor.h"
 #include "mmdeploy/core/utils/device_utils.h"
 #include "mmdeploy/core/utils/formatter.h"
+#include "mmdeploy/operation/managed.h"
+#include "mmdeploy/operation/vision.h"
 #include "mmdeploy/preprocess/transform/transform.h"
 #include "opencv2/imgproc.hpp"
 #include "opencv_utils.h"
@@ -32,6 +34,7 @@ class TopDownAffine : public transform::Transform {
     stream_ = args["context"]["stream"].get<Stream>();
     assert(args.contains("image_size"));
     from_value(args["image_size"], image_size_);
+    warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
   }
 
   ~TopDownAffine() override = default;
@@ -39,11 +42,7 @@ Result<void> Apply(Value& data) override {
     MMDEPLOY_DEBUG("top_down_affine input: {}", data);
 
-    Device host{"cpu"};
-    auto _img = data["img"].get<Tensor>();
-    OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_));
-    stream_.Wait().value();
-    auto src = cpu::Tensor2CVMat(img);
+    auto img = data["img"].get<Tensor>();
 
     // prepare data
     vector<float> bbox;
@@ -62,21 +61,20 @@
     auto r = data["rotation"].get<float>();
 
-    cv::Mat dst;
+    Tensor dst;
     if (use_udp_) {
       cv::Mat trans =
           GetWarpMatrix(r, {c[0] * 2.f, c[1] * 2.f}, {image_size_[0] - 1.f, image_size_[1] - 1.f},
                         {s[0] * 200.f, s[1] * 200.f});
-
-      cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
+      OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
     } else {
       cv::Mat trans =
           GetAffineTransform({c[0], c[1]}, {s[0], s[1]}, r, {image_size_[0], image_size_[1]});
-      cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
+      OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
     }
-    data["img"] = cpu::CVMat2Tensor(dst);
-    data["img_shape"] = {1, image_size_[1], image_size_[0], dst.channels()};
+    data["img_shape"] = {1, image_size_[1], image_size_[0], dst.shape(3)};
+    data["img"] = std::move(dst);
     data["center"] = to_value(c);
     data["scale"] = to_value(s);
     MMDEPLOY_DEBUG("output: {}", data);
@@ -106,7 +104,7 @@ class TopDownAffine : public transform::Transform {
     theta = theta * 3.1415926 / 180;
     float scale_x = size_dst.width / size_target.width;
     float scale_y = size_dst.height / size_target.height;
-    cv::Mat matrix = cv::Mat(2, 3, CV_32FC1);
+    cv::Mat matrix = cv::Mat(2, 3, CV_32F);
     matrix.at<float>(0, 0) = std::cos(theta) * scale_x;
     matrix.at<float>(0, 1) = -std::sin(theta) * scale_x;
     matrix.at<float>(0, 2) =
@@ -142,6 +140,7 @@ class TopDownAffine : public transform::Transform {
     cv::Mat trans = inv ? cv::getAffineTransform(dst_points, src_points)
                         : cv::getAffineTransform(src_points, dst_points);
+    trans.convertTo(trans, CV_32F);
     return trans;
   }
@@ -160,6 +159,7 @@ class TopDownAffine : public transform::Transform {
   }
 
  protected:
+  operation::Managed<operation::WarpAffine> warp_affine_;
   bool use_udp_{false};
   vector<int> image_size_;
   std::string backend_;
diff --git a/csrc/mmdeploy/operation/cpu/CMakeLists.txt b/csrc/mmdeploy/operation/cpu/CMakeLists.txt
index 7a4edad414..d1310baaef 100644
--- a/csrc/mmdeploy/operation/cpu/CMakeLists.txt
+++ b/csrc/mmdeploy/operation/cpu/CMakeLists.txt
@@ -9,7 +9,8 @@ set(SRCS resize.cpp
          hwc2chw.cpp
          normalize.cpp
          crop.cpp
-         flip.cpp)
+         flip.cpp
+         warp_affine.cpp)
 
 mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
diff --git a/csrc/mmdeploy/operation/cpu/resize.cpp b/csrc/mmdeploy/operation/cpu/resize.cpp
index 8c345df17e..33c5ce313b 100644
--- a/csrc/mmdeploy/operation/cpu/resize.cpp
+++ b/csrc/mmdeploy/operation/cpu/resize.cpp
@@ -7,7 +7,7 @@ namespace mmdeploy::operation::cpu {
 
 class ResizeImpl : public Resize {
  public:
-  ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
+  explicit ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
 
   Result<void> apply(const Tensor& src, Tensor& dst, int dst_h, int dst_w) override {
     auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
diff --git a/csrc/mmdeploy/operation/cpu/warp_affine.cpp b/csrc/mmdeploy/operation/cpu/warp_affine.cpp
new file mode 100644
index 0000000000..5b5914db71
--- /dev/null
+++ b/csrc/mmdeploy/operation/cpu/warp_affine.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "mmdeploy/operation/vision.h"
+#include "mmdeploy/utils/opencv/opencv_utils.h"
+
+namespace mmdeploy::operation::cpu {
+
+class WarpAffineImpl : public WarpAffine {
+ public:
+  explicit WarpAffineImpl(int method) : method_(method) {}
+
+  Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
+                     int dst_w) override {
+    auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
+    cv::Mat_<float> _matrix(2, 3, const_cast<float*>(affine_matrix));
+    auto dst_mat = mmdeploy::cpu::WarpAffine(src_mat, _matrix, dst_h, dst_w, method_);
+    dst = mmdeploy::cpu::CVMat2Tensor(dst_mat);
+    return success();
+  }
+
+ private:
+  int method_;
+};
+
+MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cpu, 0), [](const string_view& interp) {
+  return std::make_unique<WarpAffineImpl>(::mmdeploy::cpu::GetInterpolationMethod(interp).value());
+});
+
+}  // namespace mmdeploy::operation::cpu
diff --git a/csrc/mmdeploy/operation/cuda/CMakeLists.txt b/csrc/mmdeploy/operation/cuda/CMakeLists.txt
index d962d3c5bb..5e04f640bd 100644
--- a/csrc/mmdeploy/operation/cuda/CMakeLists.txt
+++ b/csrc/mmdeploy/operation/cuda/CMakeLists.txt
@@ -17,7 +17,8 @@ set(SRCS resize.cpp
          normalize.cu
          crop.cpp
          crop.cu
-         flip.cpp)
+         flip.cpp
+         warp_affine.cpp)
 
 mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
diff --git a/csrc/mmdeploy/operation/cuda/warp_affine.cpp b/csrc/mmdeploy/operation/cuda/warp_affine.cpp
new file mode 100644
index 0000000000..4f2071c068
--- /dev/null
+++ b/csrc/mmdeploy/operation/cuda/warp_affine.cpp
@@ -0,0 +1,118 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+ +#include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy/operation/vision.h" +#include "ppl/cv/cuda/warpaffine.h" + +namespace mmdeploy::operation::cuda { + +class WarpAffineImpl : public WarpAffine { + public: + explicit WarpAffineImpl(ppl::cv::InterpolationType interp) : interp_(interp) {} + + Result apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h, + int dst_w) override { + assert(src.device() == device()); + + TensorDesc desc{device(), src.data_type(), {1, dst_h, dst_w, src.shape(3)}, src.name()}; + Tensor dst_tensor(desc); + + const auto m = affine_matrix; + auto inv = Invert(affine_matrix); + + auto cuda_stream = GetNative(stream()); + if (src.data_type() == DataType::kINT8) { + OUTCOME_TRY(Dispatch(src, dst_tensor, inv.data(), cuda_stream)); + } else if (src.data_type() == DataType::kFLOAT) { + OUTCOME_TRY(Dispatch(src, dst_tensor, inv.data(), cuda_stream)); + } else { + MMDEPLOY_ERROR("unsupported data type {}", src.data_type()); + return Status(eNotSupported); + } + + dst = std::move(dst_tensor); + return success(); + } + + private: + // ppl.cv uses inverted transform + // https://github.com/opencv/opencv/blob/bc6544c0bcfa9ca5db5e0d0551edf5c8e7da3852/modules/imgproc/src/imgwarp.cpp#L3478 + static std::array Invert(const float affine_matrix[6]) { + const auto* M = affine_matrix; + std::array inv{}; + auto iM = inv.data(); + + auto D = M[0] * M[3 + 1] - M[1] * M[3]; + D = D != 0.f ? 1.f / D : 0.f; + auto A11 = M[3 + 1] * D, A22 = M[0] * D, A12 = -M[1] * D, A21 = -M[3] * D; + auto b1 = -A11 * M[2] - A12 * M[3 + 2]; + auto b2 = -A21 * M[2] - A22 * M[3 + 2]; + + iM[0] = A11; + iM[1] = A12; + iM[2] = b1; + iM[3] = A21; + iM[3 + 1] = A22; + iM[3 + 2] = b2; + + return inv; + } + + template + auto Select(int channels) -> decltype(&ppl::cv::cuda::WarpAffine) { + switch (channels) { + case 1: + return &ppl::cv::cuda::WarpAffine; + case 3: + return &ppl::cv::cuda::WarpAffine; + case 4: + return &ppl::cv::cuda::WarpAffine; + default: + MMDEPLOY_ERROR("unsupported channels {}", channels); + return nullptr; + } + } + + template + Result Dispatch(const Tensor& src, Tensor& dst, const float affine_matrix[6], + cudaStream_t stream) { + int h = (int)src.shape(1); + int w = (int)src.shape(2); + int c = (int)src.shape(3); + int dst_h = (int)dst.shape(1); + int dst_w = (int)dst.shape(2); + + auto input = src.data(); + auto output = dst.data(); + + ppl::common::RetCode ret = 0; + + if (auto warp_affine = Select(c); warp_affine) { + ret = warp_affine(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output, affine_matrix, + interp_, ppl::cv::BORDER_CONSTANT, 0); + } else { + return Status(eNotSupported); + } + + return ret == 0 ? 
success() : Result(Status(eFail)); + } + + ppl::cv::InterpolationType interp_; +}; + +static auto Create(const string_view& interp) { + ppl::cv::InterpolationType type{}; + if (interp == "bilinear") { + type = ppl::cv::InterpolationType::INTERPOLATION_LINEAR; + } else if (interp == "nearest") { + type = ppl::cv::InterpolationType::INTERPOLATION_NEAREST_POINT; + } else { + MMDEPLOY_ERROR("unsupported interpolation method: {}", interp); + throw_exception(eNotSupported); + } + return std::make_unique(type); +} + +MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cuda, 0), Create); + +} // namespace mmdeploy::operation::cuda diff --git a/csrc/mmdeploy/operation/vision.cpp b/csrc/mmdeploy/operation/vision.cpp index c7f7ba77d0..35076e2bdb 100644 --- a/csrc/mmdeploy/operation/vision.cpp +++ b/csrc/mmdeploy/operation/vision.cpp @@ -12,5 +12,6 @@ MMDEPLOY_DEFINE_REGISTRY(HWC2CHW); MMDEPLOY_DEFINE_REGISTRY(Normalize); MMDEPLOY_DEFINE_REGISTRY(Crop); MMDEPLOY_DEFINE_REGISTRY(Flip); +MMDEPLOY_DEFINE_REGISTRY(WarpAffine); } // namespace mmdeploy::operation diff --git a/csrc/mmdeploy/operation/vision.h b/csrc/mmdeploy/operation/vision.h index aea99859c5..9b65dbaaac 100644 --- a/csrc/mmdeploy/operation/vision.h +++ b/csrc/mmdeploy/operation/vision.h @@ -76,7 +76,13 @@ class Flip : public Operation { }; MMDEPLOY_DECLARE_REGISTRY(Flip, unique_ptr(int flip_code)); -// TODO: warp affine +// 2x3 OpenCV affine matrix, row major +class WarpAffine : public Operation { + public: + virtual Result apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], + int dst_h, int dst_w) = 0; +}; +MMDEPLOY_DECLARE_REGISTRY(WarpAffine, unique_ptr(const string_view& interp)); } // namespace mmdeploy::operation diff --git a/csrc/mmdeploy/preprocess/transform/load.cpp b/csrc/mmdeploy/preprocess/transform/load.cpp index 5640d1c478..57879d5b4b 100644 --- a/csrc/mmdeploy/preprocess/transform/load.cpp +++ b/csrc/mmdeploy/preprocess/transform/load.cpp @@ -48,6 +48,12 @@ class PrepareImage : public Transform { Result Apply(Value& data) override { MMDEPLOY_DEBUG("input: {}", data); + + // early exit + if (data.contains("img") && data["img"].is_any()) { + return success(); + } + assert(data.contains("ori_img")); Mat src_mat = data["ori_img"].get(); diff --git a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp index b2801cb3ea..d410d5dcc8 100644 --- a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp +++ b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp @@ -106,23 +106,34 @@ Tensor CVMat2Tensor(const cv::Mat& mat) { return Tensor{desc, data}; } +Result GetInterpolationMethod(const std::string_view& method) { + if (method == "bilinear") { + return cv::INTER_LINEAR; + } else if (method == "nearest") { + return cv::INTER_NEAREST; + } else if (method == "area") { + return cv::INTER_AREA; + } else if (method == "bicubic") { + return cv::INTER_CUBIC; + } else if (method == "lanczos") { + return cv::INTER_LANCZOS4; + } + MMDEPLOY_ERROR("unsupported interpolation method: {}", method); + return Status(eNotSupported); +} + cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width, const std::string& interpolation) { cv::Mat dst(dst_height, dst_width, src.type()); - if (interpolation == "bilinear") { - cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_LINEAR); - } else if (interpolation == "nearest") { - cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_NEAREST); - } else if (interpolation == "area") { - cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_AREA); - } else if (interpolation == "bicubic") { - 
cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_CUBIC); - } else if (interpolation == "lanczos") { - cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_LANCZOS4); - } else { - MMDEPLOY_ERROR("{} interpolation is not supported", interpolation); - assert(0); - } + auto method = GetInterpolationMethod(interpolation).value(); + cv::resize(src, dst, dst.size(), method); + return dst; +} + +cv::Mat WarpAffine(const cv::Mat& src, const cv::Mat& affine_matrix, int dst_height, int dst_width, + int interpolation) { + cv::Mat dst(dst_height, dst_width, src.type()); + cv::warpAffine(src, dst, affine_matrix, dst.size(), interpolation); return dst; } diff --git a/csrc/mmdeploy/utils/opencv/opencv_utils.h b/csrc/mmdeploy/utils/opencv/opencv_utils.h index 6f5e432b95..0c9646466d 100644 --- a/csrc/mmdeploy/utils/opencv/opencv_utils.h +++ b/csrc/mmdeploy/utils/opencv/opencv_utils.h @@ -18,6 +18,8 @@ MMDEPLOY_API cv::Mat Tensor2CVMat(const framework::Tensor& tensor); MMDEPLOY_API framework::Mat CVMat2Mat(const cv::Mat& mat, PixelFormat format); MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat); +MMDEPLOY_API Result GetInterpolationMethod(const std::string_view& method); + /** * @brief resize an image to specified size * @@ -29,6 +31,9 @@ MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat); MMDEPLOY_API cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width, const std::string& interpolation); +MMDEPLOY_API cv::Mat WarpAffine(const cv::Mat& src, const cv::Mat& affine_matrix, int dst_height, + int dst_width, int interpolation); + /** * @brief crop an image * diff --git a/demo/csrc/cpp/pose_tracker.cpp b/demo/csrc/cpp/pose_tracker.cpp index 896fe75653..1ddab89784 100644 --- a/demo/csrc/cpp/pose_tracker.cpp +++ b/demo/csrc/cpp/pose_tracker.cpp @@ -1,5 +1,8 @@ +#include +#include + #include "mmdeploy/archive/json_archive.h" #include "mmdeploy/archive/value_archive.h" #include "mmdeploy/common.hpp" @@ -15,9 +18,17 @@ const auto config_json = R"( { "type": "Pipeline", - "input": ["data", "use_det", "state"], + "input": ["img", "use_det", "state"], "output": "targets", "tasks": [ + { + "type": "Task", + "module": "Transform", + "name": "preload", + "input": "img", + "output": "data", + "transforms": [ { "type": "LoadImageFromFile" } ] + }, { "type": "Cond", "input": ["use_det", "data"], @@ -32,7 +43,7 @@ const auto config_json = R"( "type": "Task", "module": "ProcessBboxes", "input": ["dets", "data", "state"], - "output": "rois" + "output": ["rois", "track_ids"] }, { "input": "*rois", @@ -45,7 +56,7 @@ const auto config_json = R"( "type": "Task", "module": "TrackPose", "scheduler": "pool", - "input": ["keypoints", "state"], + "input": ["keypoints", "track_ids", "state"], "output": "targets" } ] @@ -57,26 +68,38 @@ namespace mmdeploy { #define REGISTER_SIMPLE_MODULE(name, fn) \ MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (name, 0), [](const Value&) { return CreateTask(fn); }); -std::optional> keypoints_to_bbox(const std::vector& keypoints, - const std::vector& scores, float img_h, - float img_w, float scale = 1.5, - float kpt_thr = 0.3) { - auto valid = false; +#define POSE_TRACKER_DEBUG(...) 
MMDEPLOY_INFO(__VA_ARGS__) + +using std::vector; +using Bbox = std::array; +using Bboxes = vector; +using Point = cv::Point2f; +using Points = vector; +using Score = float; +using Scores = vector; + +// scale = 1.5, kpt_thr = 0.3 +std::optional keypoints_to_bbox(const Points& keypoints, const Scores& scores, float img_h, + float img_w, float scale, float kpt_thr, int min_keypoints) { + int valid = 0; auto x1 = static_cast(img_w); auto y1 = static_cast(img_h); auto x2 = 0.f; auto y2 = 0.f; for (size_t i = 0; i < keypoints.size(); ++i) { auto& kpt = keypoints[i]; - if (scores[i] > kpt_thr) { + if (scores[i] >= kpt_thr) { x1 = std::min(x1, kpt.x); y1 = std::min(y1, kpt.y); x2 = std::max(x2, kpt.x); y2 = std::max(y2, kpt.y); - valid = true; + ++valid; } } - if (!valid) { + if (min_keypoints < 0) { + min_keypoints = (static_cast(scores.size()) + 1) / 2; + } + if (valid < min_keypoints) { return std::nullopt; } auto xc = .5f * (x1 + x2); @@ -92,223 +115,781 @@ std::optional> keypoints_to_bbox(const std::vector Predict(float t) = 0; + virtual cv::Mat_ Correct(const cv::Mat_& x) = 0; +}; + +class OneEuroFilter : public Filter { + public: + explicit OneEuroFilter(const cv::Mat_& x, float beta, float fc_min, float fc_d) + : x_(x.clone()), beta_(beta), fc_min_(fc_min), fc_d_(fc_d) { + v_ = cv::Mat::zeros(x_.size(), x.type()); + } + + cv::Mat_ Predict(float t) override { return x_ + v_; } + + cv::Mat_ Correct(const cv::Mat_& x) override { + auto a_v = SmoothingFactor(fc_d_); + v_ = ExponentialSmoothing(a_v, x - x_, v_); + auto fc = fc_min_ + beta_ * (float)cv::norm(v_); + auto a_x = SmoothingFactor(fc); + x_ = ExponentialSmoothing(a_x, x, x_); + return x_.clone(); + } + + private: + static float SmoothingFactor(float cutoff) { + static constexpr float kPi = 3.1415926; + auto r = 2 * kPi * cutoff; + return r / (r + 1); + } + + static cv::Mat_ ExponentialSmoothing(float a, const cv::Mat_& x, + const cv::Mat_& x0) { + return a * x + (1 - a) * x0; + } + + private: + cv::Mat_ x_; + cv::Mat_ v_; + float beta_; + float fc_min_; + float fc_d_; +}; + +template +class PointFilterArray : public Filter { + public: + template + explicit PointFilterArray(const Points& ps, const Args&... 
args) { + for (const auto& p : ps) { + fs_.emplace_back(cv::Mat_(p, false), args...); + } + } + + cv::Mat_ Predict(float t) override { + cv::Mat_ m(fs_.size() * 2, 1); + for (int i = 0; i < fs_.size(); ++i) { + cv::Range r(i * 2, i * 2 + 2); + fs_[i].Predict(1).copyTo(m.rowRange(r)); + } + return m.reshape(0, fs_.size()); + } + + cv::Mat_ Correct(const cv::Mat_& x) override { + cv::Mat_ m(fs_.size() * 2, 1); + auto _x = x.reshape(1, x.rows * x.cols); + for (int i = 0; i < fs_.size(); ++i) { + cv::Range r(i * 2, i * 2 + 2); + fs_[i].Correct(_x.rowRange(r)).copyTo(m.rowRange(r)); + } + return m.reshape(0, fs_.size()); + } + + private: + vector fs_; +}; + +class TrackerFilter { + public: + using Points = vector; + + explicit TrackerFilter(float c_beta, float c_fc_min, float c_fc_d, float k_beta, float k_fc_min, + float k_fc_d, const Bbox& bbox, const Points& kpts) + : n_kpts_(kpts.size()) { + c_ = std::make_unique(cv::Mat_(Center(bbox)), c_beta, c_fc_min, c_fc_d); + s_ = std::make_unique(cv::Mat_(Scale(bbox)), 0, 1, 0); + kpts_ = std::make_unique>(kpts, k_beta, k_fc_min, k_fc_d); + } + + std::pair Predict() { + cv::Point2f c; + c_->Predict(1).copyTo(cv::Mat(c, false)); + cv::Point2f s; + s_->Predict(0).copyTo(cv::Mat(s, false)); + Points p(n_kpts_); + kpts_->Predict(1).copyTo(cv::Mat(p, false).reshape(1)); + return {GetBbox(c, s), std::move(p)}; + } + + std::pair Correct(const Bbox& bbox, const Points& kpts) { + cv::Point2f c; + c_->Correct(cv::Mat_(Center(bbox), false)).copyTo(cv::Mat(c, false)); + cv::Point2f s; + s_->Correct(cv::Mat_(Scale(bbox), false)).copyTo(cv::Mat(s, false)); + Points p(kpts.size()); + kpts_->Correct(cv::Mat(kpts, false)).copyTo(cv::Mat(p, false).reshape(1)); + return {GetBbox(c, s), std::move(p)}; + } + + private: + static cv::Point2f Center(const Bbox& bbox) { + return {.5f * (bbox[0] + bbox[2]), .5f * (bbox[1] + bbox[3])}; + } + static cv::Point2f Scale(const Bbox& bbox) { + return {bbox[2] - bbox[0], bbox[3] - bbox[1]}; + // return {std::log(bbox[2] - bbox[0]), std::log(bbox[3] - bbox[1])}; + } + static Bbox GetBbox(const cv::Point2f& center, const cv::Point2f& scale) { + // cv::Point2f half_size(.5 * std::exp(scale.x), .5 * std::exp(scale.y)); + Point half_size(.5f * scale.x, .5f * scale.y); + auto lo = center - half_size; + auto hi = center + half_size; + return {lo.x, lo.y, hi.x, hi.y}; + } + int n_kpts_; + std::unique_ptr c_; + std::unique_ptr s_; + std::unique_ptr kpts_; +}; + struct Track { - std::vector> keypoints; - std::vector> scores; - std::vector> bboxes; + vector keypoints; + vector scores; + vector avg_scores; + vector bboxes; + vector is_missing; int64_t track_id{-1}; + std::shared_ptr filter; + Bbox bbox_pred{}; + Points kpts_pred; + int64_t age{0}; + int64_t n_missing{0}; }; struct TrackInfo { - std::vector tracks; + vector tracks; int64_t next_id{0}; }; -MMDEPLOY_REGISTER_TYPE_ID(TrackInfo, 0xcfe87980aa895d3a); // randomly generated type id +static inline float Area(const Bbox& bbox) { return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]); } + +struct TrackerParams { + // detector params + int det_interval = 5; // detection interval + int det_label = 0; // label used to filter detections + float det_min_bbox_size = 100; // threshold for sqrt(area(bbox)) + float det_thr = .5f; // confidence threshold used to filter detections + float det_nms_thr = .7f; // detection nms threshold + + // pose model params + int pose_max_num_bboxes = 1; // max num of bboxes for pose model per frame + int pose_min_keypoints = -1; // min of visible key-points for valid 
bbox, -1 -> len(kpts)/2 + float pose_min_bbox_size = 64; // threshold for sqrt(area(bbox)) + vector sigmas; // sigmas for key-points + + // tracker params + float track_nms_oks_thr = .5f; // OKS threshold for suppressing duplicated key-points + float track_kpts_thr = .6f; // threshold for key-point visibility + float track_oks_thr = .3f; // OKS assignment threshold + float track_iou_thr = .3f; // IOU assignment threshold + float track_bbox_scale = 1.25f; // scale factor for bboxes + int track_max_missing = 10; // max number of missing frames before track removal + float track_missing_momentum = .95f; // extrapolation momentum for missing tracks + int track_n_history = 10; // track history length + + // filter params for bbox center + float filter_c_beta = .005; + float filter_c_fc_min = .05; + float filter_c_fc_d = 1.; + // filter params for key-points + float filter_k_beta = .0075; + float filter_k_fc_min = .1; + float filter_k_fc_d = .25; +}; + +class Tracker { + public: + explicit Tracker(const TrackerParams& _params) : params(_params) {} + // xyxy format + float IntersectionOverUnion(const std::array& a, const std::array& b) { + auto x1 = std::max(a[0], b[0]); + auto y1 = std::max(a[1], b[1]); + auto x2 = std::min(a[2], b[2]); + auto y2 = std::min(a[3], b[3]); + + auto inter_area = std::max(0.f, x2 - x1) * std::max(0.f, y2 - y1); -Value::Array GetObjectsByTracking(Value& state, int img_h, int img_w) { - Value::Array objs; - auto& track_info = state["track_info"].get_ref(); - for (auto& track : track_info.tracks) { - auto bbox = keypoints_to_bbox(track.keypoints.back(), track.scores.back(), - static_cast(img_h), static_cast(img_w)); - if (bbox) { - objs.push_back({{"bbox", to_value(*bbox)}}); + auto a_area = Area(a); + auto b_area = Area(b); + auto union_area = a_area + b_area - inter_area; + + if (union_area == 0.f) { + return 0; } + + return inter_area / union_area; } - return objs; -} -Value ProcessBboxes(const Value& detections, const Value& data, Value state) { - assert(state.is_pointer()); - Value::Array bboxes; - if (detections.is_array()) { // has detections - auto& dets = detections.array(); - for (const auto& det : dets) { - if (det["label_id"].get() == 0 && det["score"].get() >= .3f) { - bboxes.push_back(det); + // TopDownAffine's internal logic for mapping pose detector inputs + Bbox MapBbox(const Bbox& box) { + Point p0(box[0], box[1]); + Point p1(box[2], box[3]); + auto c = .5f * (p0 + p1); + auto s = p1 - p0; + static constexpr std::array image_size{192.f, 256.f}; + float aspect_ratio = image_size[0] * 1.0 / image_size[1]; + if (s.x > aspect_ratio * s.y) { + s.y = s.x / aspect_ratio; + } else if (s.x < aspect_ratio * s.y) { + s.x = s.y * aspect_ratio; + } + s.x *= 1.25f; + s.y *= 1.25f; + p0 = c - .5f * s; + p1 = c + .5f * s; + return {p0.x, p0.y, p1.x, p1.y}; + } + + template + vector SuppressNonMaximum(const vector& scores, const vector& similarities, + vector is_valid, float thresh) { + assert(is_valid.size() == scores.size()); + vector indices(scores.size()); + std::iota(indices.begin(), indices.end(), 0); + // stable sort, useful when the scores are equal + std::sort(indices.begin(), indices.end(), [&](int i, int j) { return scores[i] > scores[j]; }); + // suppress similar samples + for (int i = 0; i < indices.size(); ++i) { + if (auto u = indices[i]; is_valid[u]) { + for (int j = i + 1; j < indices.size(); ++j) { + if (auto v = indices[j]; is_valid[v]) { + if (similarities[u * scores.size() + v] >= thresh) { + is_valid[v] = false; + } + } + } } } - 
MMDEPLOY_INFO("bboxes by detection: {}", bboxes.size()); - state["bboxes"] = bboxes; - } else { // no detections, use tracked results - auto img_h = state["img_shape"][0].get(); - auto img_w = state["img_shape"][1].get(); - bboxes = GetObjectsByTracking(state, img_h, img_w); - MMDEPLOY_INFO("GetObjectsByTracking: {}", bboxes.size()); + return is_valid; } - // attach bboxes to image data - for (auto& bbox : bboxes) { - auto img = data["ori_img"].get(); - auto box = from_value>(bbox["bbox"]); - cv::Rect rect(cv::Rect2f(cv::Point2f(box[0], box[1]), cv::Point2f(box[2], box[3]))); - bbox = Value::Object{ - {"ori_img", img}, {"bbox", {rect.x, rect.y, rect.width, rect.height}}, {"rotation", 0.f}}; + + struct Detections { + Bboxes bboxes; + Scores scores; + vector labels; }; - return bboxes; -} -REGISTER_SIMPLE_MODULE(ProcessBboxes, ProcessBboxes); -// xyxy format -float ComputeIoU(const std::array& a, const std::array& b) { - auto x1 = std::max(a[0], b[0]); - auto y1 = std::max(a[1], b[1]); - auto x2 = std::min(a[2], b[2]); - auto y2 = std::min(a[3], b[3]); + void GetObjectsByDetection(const Detections& dets, vector& bboxes, + vector& track_ids, vector& types) const { + auto& [_bboxes, _scores, _labels] = dets; + for (size_t i = 0; i < _bboxes.size(); ++i) { + if (_labels[i] == params.det_label && _scores[i] > params.det_thr && + Area(_bboxes[i]) >= params.det_min_bbox_size * params.det_min_bbox_size) { + bboxes.push_back(_bboxes[i]); + track_ids.push_back(-1); + types.push_back(1); + } + } + } + + void GetObjectsByTracking(vector& bboxes, vector& track_ids, + vector& types) const { + for (auto& track : track_info.tracks) { + std::optional bbox; + if (track.n_missing) { + bbox = track.bbox_pred; + } else { + bbox = keypoints_to_bbox(track.kpts_pred, track.scores.back(), static_cast(frame_h), + static_cast(frame_w), params.track_bbox_scale, + params.track_kpts_thr, params.pose_min_keypoints); + } + if (bbox && Area(*bbox) >= params.pose_min_bbox_size * params.pose_min_bbox_size) { + bboxes.push_back(*bbox); + track_ids.push_back(track.track_id); + types.push_back(track.n_missing ? 
0 : 2); + } + } + } + + std::tuple, vector> ProcessBboxes(const std::optional& dets) { + vector bboxes; + vector track_ids; - auto inter_area = std::max(0.f, x2 - x1) * std::max(0.f, y2 - y1); + // 2 - visible tracks + // 1 - detection + // 0 - missing tracks + vector types; - auto a_area = (a[2] - a[0]) * (a[3] - a[1]); - auto b_area = (b[2] - b[0]) * (b[3] - b[1]); - auto union_area = a_area + b_area - inter_area; + if (dets) { + GetObjectsByDetection(*dets, bboxes, track_ids, types); + } - if (union_area == 0.f) { - return 0; + GetObjectsByTracking(bboxes, track_ids, types); + + vector is_valid_bboxes(bboxes.size(), 1); + + auto count = [&] { + std::array acc{}; + for (size_t i = 0; i < is_valid_bboxes.size(); ++i) { + if (is_valid_bboxes[i]) { + ++acc[types[i]]; + } + } + return acc; + }; + POSE_TRACKER_DEBUG("frame {}, bboxes {}", frame_id, count()); + + vector> ranks; + ranks.reserve(bboxes.size()); + for (int i = 0; i < bboxes.size(); ++i) { + ranks.emplace_back(types[i], Area(bboxes[i])); + } + + vector iou(ranks.size() * ranks.size()); + for (int i = 0; i < bboxes.size(); ++i) { + for (int j = 0; j < i; ++j) { + iou[i * bboxes.size() + j] = iou[j * bboxes.size() + i] = + IntersectionOverUnion(bboxes[i], bboxes[j]); + } + } + + is_valid_bboxes = + SuppressNonMaximum(ranks, iou, std::move(is_valid_bboxes), params.det_nms_thr); + POSE_TRACKER_DEBUG("frame {}, bboxes after nms: {}", frame_id, count()); + + vector idxs; + idxs.reserve(bboxes.size()); + for (int i = 0; i < bboxes.size(); ++i) { + if (is_valid_bboxes[i]) { + idxs.push_back(i); + } + } + + std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) { return ranks[i] > ranks[j]; }); + std::fill(is_valid_bboxes.begin(), is_valid_bboxes.end(), 0); + { + vector tmp_bboxes; + vector tmp_track_ids; + for (const auto& i : idxs) { + if (tmp_bboxes.size() >= params.pose_max_num_bboxes) { + break; + } + tmp_bboxes.push_back(bboxes[i]); + tmp_track_ids.push_back(track_ids[i]); + is_valid_bboxes[i] = 1; + } + bboxes = std::move(tmp_bboxes); + track_ids = std::move(tmp_track_ids); + } + + POSE_TRACKER_DEBUG("frame {}, bboxes after sort: {}", frame_id, count()); + + pose_bboxes.clear(); + for (const auto& bbox : bboxes) { + // pose_bboxes.push_back(MapBbox(bbox)); + pose_bboxes.push_back(bbox); + } + + return {bboxes, track_ids}; } - return inter_area / union_area; -} + float ObjectKeypointSimilarity(const Points& pts_a, const Bbox& box_a, const Points& pts_b, + const Bbox& box_b) { + assert(pts_a.size() == sigmas.size()); + assert(pts_b.size() == sigmas.size()); + auto scale = [](const Bbox& bbox) -> float { + auto a = bbox[2] - bbox[0]; + auto b = bbox[3] - bbox[1]; + return std::sqrt(a * a + b * b); + }; + auto oks = [](const Point& pa, const Point& pb, float s, float k) { + return std::exp(-(pa - pb).dot(pa - pb) / (2.f * s * s * k * k)); + }; + auto sum = 0.f; + const auto s = .5f * (scale(box_a) + scale(box_b)); + for (int i = 0; i < params.sigmas.size(); ++i) { + sum += oks(pts_a[i], pts_b[i], s, params.sigmas[i]); + } + sum /= static_cast(params.sigmas.size()); + return sum; + } -void UpdateTrack(Track& track, std::vector& keypoints, std::vector& score, - const std::array& bbox, int n_history) { - if (track.scores.size() == n_history) { - std::rotate(track.keypoints.begin(), track.keypoints.begin() + 1, track.keypoints.end()); - std::rotate(track.scores.begin(), track.scores.begin() + 1, track.scores.end()); - std::rotate(track.bboxes.begin(), track.bboxes.begin() + 1, track.bboxes.end()); - track.keypoints.back() = 
std::move(keypoints); - track.scores.back() = std::move(score); - track.bboxes.back() = bbox; - } else { - track.keypoints.push_back(std::move(keypoints)); - track.scores.push_back(std::move(score)); - track.bboxes.push_back(bbox); + void UpdateTrack(Track& track, Points kpts, Scores score, const Bbox& bbox, int is_missing) { + auto avg_score = std::accumulate(score.begin(), score.end(), 0.f) / score.size(); + if (track.scores.size() == params.track_n_history) { + std::rotate(track.keypoints.begin(), track.keypoints.begin() + 1, track.keypoints.end()); + std::rotate(track.scores.begin(), track.scores.begin() + 1, track.scores.end()); + std::rotate(track.bboxes.begin(), track.bboxes.begin() + 1, track.bboxes.end()); + std::rotate(track.avg_scores.begin(), track.avg_scores.begin() + 1, track.avg_scores.end()); + std::rotate(track.is_missing.begin(), track.is_missing.begin() + 1, track.is_missing.end()); + track.keypoints.back() = std::move(kpts); + track.scores.back() = std::move(score); + track.bboxes.back() = bbox; + track.avg_scores.back() = avg_score; + track.is_missing.back() = is_missing; + } else { + track.keypoints.push_back(std::move(kpts)); + track.scores.push_back(std::move(score)); + track.bboxes.push_back(bbox); + track.avg_scores.push_back(avg_score); + track.is_missing.push_back(is_missing); + } + ++track.age; + track.n_missing = is_missing ? track.n_missing + 1 : 0; } -} -std::vector> GreedyAssignment(const std::vector& scores, - int n_rows, int n_cols, float thr) { - std::vector used_rows(n_rows); - std::vector used_cols(n_cols); - std::vector> assignment; - assignment.reserve(std::max(n_rows, n_cols)); - while (true) { - auto max_score = 0.f; - int max_row = -1; - int max_col = -1; - for (int i = 0; i < n_rows; ++i) { - if (!used_rows[i]) { - for (int j = 0; j < n_cols; ++j) { - if (!used_cols[j]) { - if (scores[i * n_cols + j] > max_score) { - max_score = scores[i * n_cols + j]; - max_row = i; - max_col = j; + vector> GreedyAssignment(const vector& scores, + vector& is_valid_rows, + vector& is_valid_cols, float thr) { + const auto n_rows = is_valid_rows.size(); + const auto n_cols = is_valid_cols.size(); + vector> assignment; + assignment.reserve(std::max(n_rows, n_cols)); + while (true) { + auto max_score = 0.f; + int max_row = -1; + int max_col = -1; + for (int i = 0; i < n_rows; ++i) { + if (is_valid_rows[i]) { + for (int j = 0; j < n_cols; ++j) { + if (is_valid_cols[j]) { + if (scores[i * n_cols + j] > max_score) { + max_score = scores[i * n_cols + j]; + max_row = i; + max_col = j; + } } } } } + if (max_score < thr) { + break; + } + is_valid_rows[max_row] = 0; + is_valid_cols[max_col] = 0; + assignment.emplace_back(max_row, max_col, max_score); } - if (max_score < thr) { - break; + return assignment; + } + + vector SuppressOverlappingBboxes( + const vector& keypoints, const vector& scores, + const vector& is_present, // bbox from a visible track? 
+ const vector& bboxes, vector is_valid, const vector& sigmas, + float oks_thr) { + assert(keypoints.size() == is_valid.size()); + assert(scores.size() == is_valid.size()); + assert(bboxes.size() == is_valid.size()); + const auto size = is_valid.size(); + vector oks(size * size); + for (int i = 0; i < size; ++i) { + if (is_valid[i]) { + for (int j = 0; j < i; ++j) { + if (is_valid[j]) { + oks[i * size + j] = oks[j * size + i] = + ObjectKeypointSimilarity(keypoints[i], bboxes[i], keypoints[j], bboxes[j]); + } + } + } + } + vector> ranks; + ranks.reserve(size); + for (int i = 0; i < size; ++i) { + auto& s = scores[i]; + auto avg = std::accumulate(s.begin(), s.end(), 0.f) / static_cast(s.size()); + // prevents bboxes from missing tracks to suppress visible tracks + ranks.emplace_back(is_present[i], avg); } - used_rows[max_row] = 1; - used_cols[max_col] = 1; - assignment.emplace_back(max_row, max_col, max_score); + return SuppressNonMaximum(ranks, oks, is_valid, oks_thr); } - return assignment; -} -void TrackStep(std::vector>& keypoints, - std::vector>& scores, TrackInfo& track_info, int img_h, int img_w, - float iou_thr, int min_keypoints, int n_history) { - auto& tracks = track_info.tracks; + void TrackStep(vector& keypoints, vector& scores, + const vector& track_ids) { + auto& tracks = track_info.tracks; - std::vector new_tracks; - new_tracks.reserve(tracks.size()); + vector new_tracks; + new_tracks.reserve(tracks.size()); - std::vector> bboxes; - bboxes.reserve(keypoints.size()); + vector bboxes(keypoints.size()); + vector is_valid_bboxes(keypoints.size(), 1); - std::vector indices; - indices.reserve(keypoints.size()); + pose_results.clear(); - for (size_t i = 0; i < keypoints.size(); ++i) { - if (auto bbox = keypoints_to_bbox(keypoints[i], scores[i], img_h, img_w, 1.f, 0.f)) { - bboxes.push_back(*bbox); - indices.push_back(i); + // key-points to bboxes + for (size_t i = 0; i < keypoints.size(); ++i) { + if (auto bbox = + keypoints_to_bbox(keypoints[i], scores[i], frame_h, frame_w, params.track_bbox_scale, + params.track_kpts_thr, params.pose_min_keypoints)) { + bboxes[i] = *bbox; + pose_results.push_back(*bbox); + } else { + is_valid_bboxes[i] = false; + // MMDEPLOY_INFO("frame {}: invalid key-points {}", frame_id, scores[i]); + } } - } - const auto n_rows = static_cast(bboxes.size()); - const auto n_cols = static_cast(tracks.size()); + vector is_present(is_valid_bboxes.size()); + for (int i = 0; i < track_ids.size(); ++i) { + for (const auto& t : tracks) { + if (t.track_id == track_ids[i]) { + is_present[i] = !t.n_missing; + break; + } + } + } + is_valid_bboxes = + SuppressOverlappingBboxes(keypoints, scores, is_present, bboxes, is_valid_bboxes, + params.sigmas, params.track_nms_oks_thr); + assert(is_valid_bboxes.size() == bboxes.size()); + + const auto n_rows = static_cast(bboxes.size()); + const auto n_cols = static_cast(tracks.size()); - std::vector similarities(n_rows * n_cols); - for (size_t i = 0; i < n_rows; ++i) { - for (size_t j = 0; j < n_cols; ++j) { - similarities[i * n_cols + j] = ComputeIoU(bboxes[i], tracks[j].bboxes.back()); + // generate similarity matrix + vector iou(n_rows * n_cols); + vector oks(n_rows * n_cols); + for (size_t i = 0; i < n_rows; ++i) { + const auto& bbox = bboxes[i]; + const auto& kpts = keypoints[i]; + for (size_t j = 0; j < n_cols; ++j) { + const auto& track = tracks[j]; + if (track_ids[i] != -1 && track_ids[i] != track.track_id) { + continue; + } + const auto index = i * n_cols + j; + iou[index] = IntersectionOverUnion(bbox, track.bbox_pred); + 
oks[index] = ObjectKeypointSimilarity(kpts, bbox, track.kpts_pred, track.bbox_pred); + } } - } - const auto assignment = GreedyAssignment(similarities, n_rows, n_cols, iou_thr); + vector is_valid_tracks(n_cols, 1); + // disable missing tracks in the #1 assignment + for (int i = 0; i < tracks.size(); ++i) { + if (tracks[i].n_missing) { + is_valid_tracks[i] = 0; + } + } + const auto oks_assignment = + GreedyAssignment(oks, is_valid_bboxes, is_valid_tracks, params.track_oks_thr); + + // enable missing tracks in the #2 assignment + for (int i = 0; i < tracks.size(); ++i) { + if (tracks[i].n_missing) { + is_valid_tracks[i] = 1; + } + } + const auto iou_assignment = + GreedyAssignment(iou, is_valid_bboxes, is_valid_tracks, params.track_iou_thr); - std::vector used(n_rows); - for (auto [i, j, _] : assignment) { - auto k = indices[i]; - UpdateTrack(tracks[j], keypoints[k], scores[k], bboxes[i], n_history); - new_tracks.push_back(std::move(tracks[j])); - used[i] = true; - } + POSE_TRACKER_DEBUG("frame {}, oks assignment {}", frame_id, oks_assignment); + POSE_TRACKER_DEBUG("frame {}, iou assignment {}", frame_id, iou_assignment); + + auto assignment = oks_assignment; + assignment.insert(assignment.end(), iou_assignment.begin(), iou_assignment.end()); - for (size_t i = 0; i < used.size(); ++i) { - if (used[i] == 0) { - auto k = indices[i]; - auto count = std::count_if(scores[k].begin(), scores[k].end(), [](auto x) { return x > 0; }); - if (count >= min_keypoints) { + // update assigned tracks + for (auto [i, j, _] : assignment) { + auto& track = tracks[j]; + if (track.n_missing) { + // re-initialize filter for recovering tracks + track.filter = CreateFilter(bboxes[i], keypoints[i]); + UpdateTrack(track, keypoints[i], scores[i], bboxes[i], false); + POSE_TRACKER_DEBUG("frame {}, track recovered {}", frame_id, track.track_id); + } else { + auto [bbox, kpts] = track.filter->Correct(bboxes[i], keypoints[i]); + UpdateTrack(track, std::move(kpts), std::move(scores[i]), bbox, false); + } + new_tracks.push_back(std::move(track)); + } + + // generating new tracks + for (size_t i = 0; i < is_valid_bboxes.size(); ++i) { + // only newly detected bboxes are allowed to form new tracks + if (is_valid_bboxes[i] && track_ids[i] == -1) { auto& track = new_tracks.emplace_back(); track.track_id = track_info.next_id++; - UpdateTrack(track, keypoints[k], scores[k], bboxes[i], n_history); + track.filter = CreateFilter(bboxes[i], keypoints[i]); + UpdateTrack(track, std::move(keypoints[i]), std::move(scores[i]), bboxes[i], false); + is_valid_bboxes[i] = 0; + POSE_TRACKER_DEBUG("frame {}, new track {}", frame_id, track.track_id); + } + } + + if (1) { + // diagnostic for missing tracks + int n_missing = 0; + for (int i = 0; i < is_valid_tracks.size(); ++i) { + if (is_valid_tracks[i]) { + float best_oks = 0.f; + float best_iou = 0.f; + for (int j = 0; j < is_valid_bboxes.size(); ++j) { + if (is_valid_bboxes[j]) { + best_oks = std::max(oks[j * n_cols + i], best_oks); + best_iou = std::max(iou[j * n_cols + i], best_iou); + } + } + POSE_TRACKER_DEBUG("frame {}: track missing {}, best_oks={}, best_iou={}", frame_id, + tracks[i].track_id, best_oks, best_iou); + ++n_missing; + } + } + if (n_missing) { + { + std::stringstream ss; + ss << cv::Mat_(n_rows, n_cols, oks.data()); + POSE_TRACKER_DEBUG("frame {}, oks: \n{}", frame_id, ss.str()); + } + { + std::stringstream ss; + ss << cv::Mat_(n_rows, n_cols, iou.data()); + POSE_TRACKER_DEBUG("frame {}, iou: \n{}", frame_id, ss.str()); + } + } + } + + for (int i = 0; i < 
is_valid_tracks.size(); ++i) { + if (is_valid_tracks[i]) { + if (auto& track = tracks[i]; track.n_missing < params.track_max_missing) { + // use predicted state to update missing tracks + auto [bbox, kpts] = track.filter->Correct(track.bbox_pred, track.kpts_pred); + vector score(track.kpts_pred.size()); + POSE_TRACKER_DEBUG("frame {}, track {}, bbox width {}", frame_id, track.track_id, + bbox[2] - bbox[0]); + UpdateTrack(track, std::move(kpts), std::move(score), bbox, true); + new_tracks.push_back(std::move(track)); + } else { + POSE_TRACKER_DEBUG("frame {}, track lost {}", frame_id, track.track_id); + } + is_valid_tracks[i] = false; + } + } + + tracks = std::move(new_tracks); + for (auto& t : tracks) { + if (t.n_missing == 0) { + std::tie(t.bbox_pred, t.kpts_pred) = t.filter->Predict(); + } else { + auto [bbox, kpts] = t.filter->Predict(); + const auto alpha = params.track_missing_momentum; + cv::Mat tmp_bbox = alpha * cv::Mat(bbox, false) + (1 - alpha) * cv::Mat(t.bbox_pred, false); + tmp_bbox.copyTo(cv::Mat(t.bbox_pred, false)); + } + } + + if (0) { + vector> summary; + for (const auto& track : tracks) { + summary.emplace_back(track.track_id, track.n_missing); + } + POSE_TRACKER_DEBUG("frame {}, track summary {}", frame_id, summary); + for (const auto& track : tracks) { + if (!track.n_missing) { + POSE_TRACKER_DEBUG("frame {}, track {}, scores {}", frame_id, track.track_id, + track.scores.back()); + } + } + } + } + + std::shared_ptr CreateFilter(const Bbox& bbox, const Points& kpts) const { + return std::make_shared( + params.filter_c_beta, params.filter_c_fc_min, params.filter_c_fc_d, params.filter_k_beta, + params.filter_k_fc_min, params.filter_k_fc_d, bbox, kpts); + } + + struct Target { + Bbox bbox; + vector keypoints; + Scores scores; + MMDEPLOY_ARCHIVE_MEMBERS(bbox, keypoints, scores); + }; + + vector TrackPose(vector keypoints, vector scores, + const vector& track_ids) { + TrackStep(keypoints, scores, track_ids); + vector targets; + for (const auto& track : track_info.tracks) { + if (track.n_missing) { + continue; } + if (auto bbox = keypoints_to_bbox(track.keypoints.back(), track.scores.back(), frame_h, + frame_w, params.track_bbox_scale, params.track_kpts_thr, + params.pose_min_keypoints)) { + vector kpts; + kpts.reserve(track.keypoints.back().size()); + for (const auto& kpt : track.keypoints.back()) { + kpts.emplace_back(kpt.x); + kpts.emplace_back(kpt.y); + } + targets.push_back(Target{*bbox, std::move(kpts), track.scores.back()}); + } + } + return targets; + } + + float frame_h = 0; + float frame_w = 0; + TrackInfo track_info; + + TrackerParams params; + + int frame_id = 0; + + vector pose_bboxes; + vector pose_results; +}; + +MMDEPLOY_REGISTER_TYPE_ID(Tracker, 0xcfe87980aa895d3a); + +std::tuple ProcessBboxes(const Value& det_val, const Value& data, Value state) { + auto& tracker = state.get_ref(); + + std::optional dets; + + if (det_val.is_array()) { // has detections + auto& [bboxes, scores, labels] = dets.emplace(); + for (const auto& det : det_val.array()) { + bboxes.push_back(from_value(det["bbox"])); + scores.push_back(det["score"].get()); + labels.push_back(det["label_id"].get()); } } - tracks = std::move(new_tracks); + auto [bboxes, ids] = tracker.ProcessBboxes(dets); + + Value::Array bbox_array; + Value track_ids_array; + // attach bboxes to image data + for (auto& bbox : bboxes) { + cv::Rect rect(cv::Rect2f(cv::Point2f(bbox[0], bbox[1]), cv::Point2f(bbox[2], bbox[3]))); + bbox_array.push_back({ + {"img", data["img"]}, // img + {"bbox", {rect.x, rect.y, rect.width, 
rect.height}}, // bbox + {"rotation", 0.f} // rotation + }); + } + + track_ids_array = to_value(ids); + return {std::move(bbox_array), std::move(track_ids_array)}; } -Value TrackPose(const Value& result, Value state) { - assert(state.is_pointer()); - assert(result.is_array()); - std::vector> keypoints; - std::vector> scores; - for (auto& output : result.array()) { +REGISTER_SIMPLE_MODULE(ProcessBboxes, ProcessBboxes); + +Value TrackPose(const Value& poses, const Value& track_indices, Value state) { + assert(poses.is_array()); + vector keypoints; + vector scores; + for (auto& output : poses.array()) { auto& k = keypoints.emplace_back(); auto& s = scores.emplace_back(); + float avg = 0.f; for (auto& kpt : output["key_points"].array()) { - k.push_back(cv::Point2f{kpt["bbox"][0].get(), kpt["bbox"][1].get()}); + k.emplace_back(kpt["bbox"][0].get(), kpt["bbox"][1].get()); s.push_back(kpt["score"].get()); + avg += s.back(); } } - auto& track_info = state["track_info"].get_ref(); - auto img_h = state["img_shape"][0].get(); - auto img_w = state["img_shape"][1].get(); - auto iou_thr = state["iou_thr"].get(); - auto min_keypoints = state["min_keypoints"].get(); - auto n_history = state["n_history"].get(); - TrackStep(keypoints, scores, track_info, img_h, img_w, iou_thr, min_keypoints, n_history); - - Value::Array targets; - for (const auto& track : track_info.tracks) { - if (auto bbox = keypoints_to_bbox(track.keypoints.back(), track.scores.back(), img_h, img_w)) { - Value::Array kpts; - kpts.reserve(track.keypoints.back().size()); - for (const auto& kpt : track.keypoints.back()) { - kpts.push_back(kpt.x); - kpts.push_back(kpt.y); - } - targets.push_back({{"bbox", to_value(*bbox)}, {"keypoints", std::move(kpts)}}); - } - } - return targets; + vector track_ids; + from_value(track_indices, track_ids); + auto& tracker = state.get_ref(); + auto targets = tracker.TrackPose(std::move(keypoints), std::move(scores), track_ids); + return to_value(targets); } + REGISTER_SIMPLE_MODULE(TrackPose, TrackPose); class PoseTracker { @@ -324,12 +905,10 @@ class PoseTracker { return Pipeline{config, context}; }()) {} - State CreateState() { // NOLINT - return make_pointer({{"frame_id", 0}, - {"n_history", 10}, - {"iou_thr", .3f}, - {"min_keypoints", 3}, - {"track_info", TrackInfo{}}}); + State CreateState(const TrackerParams& params) { + auto state = make_pointer(Tracker{params}); + auto& tracker = state.get_ref(); + return state; } Value Track(const Mat& img, State& state, int use_detector = -1) { @@ -337,19 +916,30 @@ class PoseTracker { framework::Mat mat(img.desc().height, img.desc().width, static_cast(img.desc().format), static_cast(img.desc().type), {img.desc().data, [](void*) {}}); - // TODO: get_ref is not working - auto frame_id = state["frame_id"].get(); + + auto& tracker = state.get_ref(); + if (use_detector < 0) { - use_detector = frame_id % 10 == 0; - if (use_detector) { - MMDEPLOY_WARN("use detector"); + if (tracker.frame_id % tracker.params.det_interval == 0) { + use_detector = 1; + POSE_TRACKER_DEBUG("frame {}, use detector", tracker.frame_id); + } else { + use_detector = 0; } } - state["frame_id"] = frame_id + 1; - state["img_shape"] = {mat.height(), mat.width()}; + + if (tracker.frame_id == 0) { + tracker.frame_h = static_cast(mat.height()); + tracker.frame_w = static_cast(mat.width()); + } + Value::Object data{{"ori_img", mat}}; Value input{{data}, {use_detector}, {state}}; - return pipeline_.Apply(input)[0][0]; + auto ret = pipeline_.Apply(input)[0][0]; + + ++tracker.frame_id; + + return ret; } 
private: @@ -360,32 +950,74 @@ class PoseTracker { using namespace mmdeploy; -void Visualize(cv::Mat& frame, const Value& result) { - static std::vector> skeleton{ +const cv::Scalar& gPalette(int index) { + static vector inst{ + {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, + {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, + {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, + {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}}; + return inst[index]; +} + +void Visualize(cv::Mat& frame, const Value& result, const Bboxes& pose_bboxes, + const Bboxes& pose_results, int size) { + static vector> skeleton{ {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}}; + static vector link_color{0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}; + static vector kpt_color{16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}; + auto scale = (float)size / (float)std::max(frame.cols, frame.rows); + if (scale != 1) { + cv::resize(frame, frame, {}, scale, scale); + } + auto draw_bbox = [](cv::Mat& image, Bbox bbox, const cv::Scalar& color, float scale = 1) { + std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; }); + cv::Point p1(bbox[0], bbox[1]); + cv::Point p2(bbox[2], bbox[3]); + cv::rectangle(image, p1, p2, color); + }; const auto& targets = result.array(); for (const auto& target : targets) { auto bbox = from_value>(target["bbox"]); - auto kpts = from_value>(target["keypoints"]); - cv::Point p1(bbox[0], bbox[1]); - cv::Point p2(bbox[2], bbox[3]); - cv::rectangle(frame, p1, p2, cv::Scalar(0, 255, 0)); - for (int i = 0; i < kpts.size(); i += 2) { - cv::Point p(kpts[i], kpts[i + 1]); - cv::circle(frame, p, 1, cv::Scalar(0, 255, 255), 2, cv::LINE_AA); + auto kpts = from_value>(target["keypoints"]); + std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; }); + std::for_each(kpts.begin(), kpts.end(), [&](auto& x) { x *= scale; }); + auto scores = from_value>(target["scores"]); + if (0) { + draw_bbox(frame, bbox, cv::Scalar(0, 255, 0)); } + constexpr auto score_thr = .5f; + vector used(kpts.size()); for (int i = 0; i < skeleton.size(); ++i) { auto [u, v] = skeleton[i]; - cv::Point p_u(kpts[u * 2], kpts[u * 2 + 1]); - cv::Point p_v(kpts[v * 2], kpts[v * 2 + 1]); - cv::line(frame, p_u, p_v, cv::Scalar(0, 255, 255), 1, cv::LINE_AA); + if (scores[u] > score_thr && scores[v] > score_thr) { + used[u] = used[v] = 1; + cv::Point p_u(kpts[u * 2], kpts[u * 2 + 1]); + cv::Point p_v(kpts[v * 2], kpts[v * 2 + 1]); + cv::line(frame, p_u, p_v, gPalette(link_color[i]), 1, cv::LINE_AA); + } + } + for (int i = 0; i < kpts.size(); i += 2) { + if (used[i / 2]) { + cv::Point p(kpts[i], kpts[i + 1]); + cv::circle(frame, p, 1, gPalette(kpt_color[i / 2]), 2, cv::LINE_AA); + } } } - cv::imshow("", frame); - cv::waitKey(1); + if (0) { + for (auto bbox : pose_bboxes) { + draw_bbox(frame, bbox, {0, 255, 255}, scale); + } + for (auto bbox : pose_results) { + draw_bbox(frame, bbox, {0, 255, 0}, scale); + } + } + static int frame_id = 0; + cv::imwrite(fmt::format("pose_{}.jpg", frame_id++), frame, {cv::IMWRITE_JPEG_QUALITY, 90}); } +// ffmpeg -f image2 -i pose_%d.jpg -vcodec hevc -crf 30 pose.mp4 + int main(int argc, char* argv[]) { const auto device_name = argv[1]; const auto det_model_path = argv[2]; @@ -396,7 +1028,14 @@ int main(int argc, char* argv[]) { Profiler 
profiler("pose_tracker.perf"); context.Add(profiler); PoseTracker tracker(Model(det_model_path), Model(pose_model_path), context); - auto state = tracker.CreateState(); + TrackerParams params; + // coco + params.sigmas = {0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, + 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089}; + params.pose_max_num_bboxes = 5; + params.det_interval = 5; + + auto state = tracker.CreateState(params); cv::Mat frame; std::chrono::duration dt{}; @@ -414,7 +1053,11 @@ int main(int argc, char* argv[]) { auto t1 = std::chrono::high_resolution_clock::now(); dt += t1 - t0; ++frame_id; - Visualize(frame, result); + + auto& pose_bboxes = state.get_ref().pose_bboxes; + auto& pose_results = state.get_ref().pose_results; + + Visualize(frame, result, pose_bboxes, pose_results, 1024); } MMDEPLOY_INFO("frames: {}, time {} ms", frame_id, dt.count()); diff --git a/docs/en/01-how-to-build/android.md b/docs/en/01-how-to-build/android.md index c67c73f5d7..a153c34567 100644 --- a/docs/en/01-how-to-build/android.md +++ b/docs/en/01-how-to-build/android.md @@ -121,7 +121,7 @@ MMDeploy provides a recipe as shown below for building SDK with ncnn as inferenc -DOpenCV_DIR=${OPENCV_ANDROID_SDK_DIR}/sdk/native/jni/abi-${ANDROID_ABI} \ -Dncnn_DIR=${NCNN_DIR}/build_${ANDROID_ABI}/install/lib/cmake/ncnn \ -DMMDEPLOY_TARGET_BACKENDS=ncnn \ - -DMMDEPLOY_SHARED_LIBS=ON \ + -DMMDEPLOY_SHARED_LIBS=OFF \ -DCMAKE_TOOLCHAIN_FILE=${NDK_PATH}/build/cmake/android.toolchain.cmake \ -DANDROID_ABI=${ANDROID_ABI} \ -DANDROID_PLATFORM=android-30 \ diff --git a/docs/zh_cn/01-how-to-build/android.md b/docs/zh_cn/01-how-to-build/android.md index a7f3fe5875..ae57748e4a 100644 --- a/docs/zh_cn/01-how-to-build/android.md +++ b/docs/zh_cn/01-how-to-build/android.md @@ -122,7 +122,7 @@ make -j$(nproc) install -DOpenCV_DIR=${OPENCV_ANDROID_SDK_DIR}/sdk/native/jni/abi-${ANDROID_ABI} \ -Dncnn_DIR=${NCNN_DIR}/build_${ANDROID_ABI}/install/lib/cmake/ncnn \ -DMMDEPLOY_TARGET_BACKENDS=ncnn \ - -DMMDEPLOY_SHARED_LIBS=ON \ + -DMMDEPLOY_SHARED_LIBS=OFF \ -DCMAKE_TOOLCHAIN_FILE=${NDK_PATH}/build/cmake/android.toolchain.cmake \ -DANDROID_ABI=${ANDROID_ABI} \ -DANDROID_PLATFORM=android-30 \