diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h index f895af4150..4b27fbab8a 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h @@ -36,7 +36,7 @@ typedef struct mmdeploy_pose_tracker_param_t { int32_t pose_max_num_bboxes; // threshold for visible key-points, default = 0.5 float pose_kpt_thr; - // min number of key-points for valid poses, default = -1 + // min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1 int32_t pose_min_keypoints; // scale for expanding key-points to bbox, default = 1.25 float pose_bbox_scale; diff --git a/csrc/mmdeploy/apis/cxx/CMakeLists.txt b/csrc/mmdeploy/apis/cxx/CMakeLists.txt index 19d56344a6..6a5160d9c6 100644 --- a/csrc/mmdeploy/apis/cxx/CMakeLists.txt +++ b/csrc/mmdeploy/apis/cxx/CMakeLists.txt @@ -24,4 +24,5 @@ install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/common.hpp install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ DESTINATION example/cpp FILES_MATCHING PATTERN "*.cxx" + PATTERN "*.h" ) diff --git a/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h b/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h index 3643c99387..676e87157d 100644 --- a/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h +++ b/csrc/mmdeploy/codebase/mmpose/pose_tracker/utils.h @@ -22,7 +22,7 @@ using Points = vector; using Score = float; using Scores = vector; -#define POSE_TRACKER_DEBUG(...) MMDEPLOY_INFO(__VA_ARGS__) +#define POSE_TRACKER_DEBUG(...) 
MMDEPLOY_DEBUG(__VA_ARGS__) // opencv3 can't construct cv::Mat from std::array template diff --git a/demo/csrc/cpp/classifier.cxx b/demo/csrc/cpp/classifier.cxx index 2d8b2d1e25..a4a12eea16 100644 --- a/demo/csrc/cpp/classifier.cxx +++ b/demo/csrc/cpp/classifier.cxx @@ -1,31 +1,43 @@ #include "mmdeploy/classifier.hpp" -#include - #include "opencv2/imgcodecs/imgcodecs.hpp" +#include "utils/argparse.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "classifier_output.jpg", "Output image path"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n image_classification device_name model_path image_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; } - mmdeploy::Model model(model_path); - mmdeploy::Classifier classifier(model, mmdeploy::Device{device_name, 0}); + // construct a classifier instance + mmdeploy::Classifier classifier(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); - auto res = classifier.Apply(img); + // apply the classifier; the result is an array-like class holding references to + // `mmdeploy_classification_t`, will be released automatically on destruction + mmdeploy::Classifier::Result result = classifier.Apply(img); + + // visualize results + utils::Visualize v; + auto sess = v.get_session(img); + int count = 0; + for (const mmdeploy_classification_t& cls : result) { + sess.add_label(cls.label_id, cls.score, count++); + } - for (const auto& cls : res) { - 
fprintf(stderr, "label: %d, score: %.4f\n", cls.label_id, cls.score); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); } return 0; diff --git a/demo/csrc/cpp/det_pose.cpp b/demo/csrc/cpp/det_pose.cpp deleted file mode 100644 index a12dd3c0b1..0000000000 --- a/demo/csrc/cpp/det_pose.cpp +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include - -#include "mmdeploy/detector.hpp" -#include "mmdeploy/pose_detector.hpp" -#include "opencv2/imgcodecs/imgcodecs.hpp" -#include "opencv2/imgproc/imgproc.hpp" - -using std::vector; - -cv::Mat Visualize(cv::Mat frame, const std::vector& poses, int size); - -int main(int argc, char* argv[]) { - const auto device_name = argv[1]; - const auto det_model_path = argv[2]; - const auto pose_model_path = argv[3]; - const auto image_path = argv[4]; - - if (argc != 5) { - std::cerr << "usage:\n\tpose_tracker device_name det_model_path pose_model_path image_path\n"; - return -1; - } - auto img = cv::imread(image_path); - if (!img.data) { - std::cerr << "failed to load image: " << image_path << "\n"; - return -1; - } - - using namespace mmdeploy; - - Context context(Device{device_name}); // create context for holding the device handle - Detector detector(Model(det_model_path), context); // create object detector - PoseDetector pose(Model(pose_model_path), context); // create pose detector - - // apply detector - auto dets = detector.Apply(img); - - // filter detections and extract bboxes for pose model - std::vector bboxes; - for (const auto& det : dets) { - if (det.label_id == 0 && det.score > .6f) { - bboxes.push_back(det.bbox); - } - } - // apply pose detector - auto poses = pose.Apply(img, bboxes); - - // visualize - auto vis = Visualize(img, {poses.begin(), poses.end()}, 1280); - cv::imwrite("det_pose_output.jpg", vis); - - return 0; -} - -struct Skeleton { - vector> skeleton; - vector palette; - vector link_color; - vector point_color; -}; - -const Skeleton& gCocoSkeleton() { 
- static const Skeleton inst{ - { - {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, - {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, - {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, - }, - { - {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, - {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, - {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, - {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}, - }, - {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}, - {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}, - }; - return inst; -} - -cv::Mat Visualize(cv::Mat frame, const vector& poses, int size) { - auto& [skeleton, palette, link_color, point_color] = gCocoSkeleton(); - auto scale = (float)size / (float)std::max(frame.cols, frame.rows); - if (scale != 1) { - cv::resize(frame, frame, {}, scale, scale); - } else { - frame = frame.clone(); - } - for (const auto& pose : poses) { - vector kpts(&pose.point[0].x, &pose.point[pose.length - 1].y + 1); - vector scores(pose.score, pose.score + pose.length); - std::for_each(kpts.begin(), kpts.end(), [&](auto& x) { x *= scale; }); - constexpr auto score_thr = .5f; - vector used(kpts.size()); - for (size_t i = 0; i < skeleton.size(); ++i) { - auto [u, v] = skeleton[i]; - if (scores[u] > score_thr && scores[v] > score_thr) { - used[u] = used[v] = 1; - cv::Point p_u(kpts[u * 2], kpts[u * 2 + 1]); - cv::Point p_v(kpts[v * 2], kpts[v * 2 + 1]); - cv::line(frame, p_u, p_v, palette[link_color[i]], 1, cv::LINE_AA); - } - } - for (size_t i = 0; i < kpts.size(); i += 2) { - if (used[i / 2]) { - cv::Point p(kpts[i], kpts[i + 1]); - cv::circle(frame, p, 1, palette[point_color[i / 2]], 2, cv::LINE_AA); - } - } - } - return frame; -} diff --git a/demo/csrc/cpp/det_pose.cxx b/demo/csrc/cpp/det_pose.cxx new file mode 100644 index 0000000000..b18fbe003b --- /dev/null +++ 
b/demo/csrc/cpp/det_pose.cxx @@ -0,0 +1,75 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "mmdeploy/detector.hpp" +#include "mmdeploy/pose_detector.hpp" +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "utils/argparse.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(det_model, "Object detection model path"); +DEFINE_ARG_string(pose_model, "Pose estimation model path"); +DEFINE_ARG_string(image, "Input image path"); + +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "det_pose_output.jpg", "Output image path"); +DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")"); + +DEFINE_int32(det_label, 0, "Detection label use for pose estimation"); +DEFINE_double(det_thr, .5, "Detection score threshold"); +DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size"); + +DEFINE_double(pose_thr, 0, "Pose key-point threshold"); + +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } + + mmdeploy::Device device{FLAGS_device}; + // create object detector + mmdeploy::Detector detector(mmdeploy::Model(ARGS_det_model), device); + // create pose detector + mmdeploy::PoseDetector pose(mmdeploy::Model(ARGS_pose_model), device); + + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_detection_t`, will be released automatically on destruction + mmdeploy::Detector::Result dets = detector.Apply(img); + + // filter detections and extract bboxes for pose model + std::vector bboxes; + for (const mmdeploy_detection_t& det : dets) { + if (det.label_id == FLAGS_det_label && det.score > FLAGS_det_thr) { + bboxes.push_back(det.bbox); + } + } + + // apply pose detector, if no bboxes are provided, full image will be used; the result is 
an + // array-like class holding references to `mmdeploy_pose_detection_t`, will be released + // automatically on destruction + mmdeploy::PoseDetector::Result poses = pose.Apply(img, bboxes); + + assert(bboxes.size() == poses.size()); + + // visualize results + utils::Visualize v; + v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton)); + auto sess = v.get_session(img); + for (size_t i = 0; i < bboxes.size(); ++i) { + sess.add_bbox(bboxes[i], -1, -1); + sess.add_pose(poses[i].point, poses[i].score, poses[i].length, FLAGS_pose_thr); + } + + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } + + return 0; +} diff --git a/demo/csrc/cpp/detector.cxx b/demo/csrc/cpp/detector.cxx index 7009d2fd21..e5b9f58183 100644 --- a/demo/csrc/cpp/detector.cxx +++ b/demo/csrc/cpp/detector.cxx @@ -1,69 +1,47 @@ #include "mmdeploy/detector.hpp" -#include -#include -#include +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "utils/argparse.h" +#include "utils/visualize.h" -int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n object_detection device_name model_path image_path\n"); - return 1; - } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; - } - - mmdeploy::Model model(model_path); - mmdeploy::Detector detector(model, mmdeploy::Device{device_name, 0}); - - auto dets = detector.Apply(img); - - fprintf(stdout, "bbox_count=%d\n", (int)dets.size()); - - for (int i = 0; i < dets.size(); ++i) { - const auto& box = dets[i].bbox; - const auto& mask = dets[i].mask; +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. 
"cpu", "cuda")"); +DEFINE_string(output, "detector_output.jpg", "Output image path"); - fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n", - i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score); +DEFINE_double(det_thr, .5, "Detection score threshold"); - // skip detections with invalid bbox size (bbox height or width < 1) - if ((box.right - box.left) < 1 || (box.bottom - box.top) < 1) { - continue; - } +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } - // skip detections less than specified score threshold - if (dets[i].score < 0.3) { - continue; - } + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } - // generate mask overlay if model exports masks - if (mask != nullptr) { - fprintf(stdout, "mask %d, height=%d, width=%d\n", i, mask->height, mask->width); + // construct a detector instance + mmdeploy::Detector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); - cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]); - auto x0 = std::max(std::floor(box.left) - 1, 0.f); - auto y0 = std::max(std::floor(box.top) - 1, 0.f); - cv::Rect roi((int)x0, (int)y0, mask->width, mask->height); + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_detection_t`, will be released automatically on destruction + mmdeploy::Detector::Result dets = detector.Apply(img); - // split the RGB channels, overlay mask to a specific color channel - cv::Mat ch[3]; - split(img, ch); - int col = 0; // int col = i % 3; - cv::bitwise_or(imgMask, ch[col](roi), ch[col](roi)); - merge(ch, 3, img); + // visualize + utils::Visualize v; + auto sess = v.get_session(img); + int count = 0; + for (const mmdeploy_detection_t& det : dets) { + if (det.score > FLAGS_det_thr) { // filter bboxes + sess.add_det(det.bbox, 
det.label_id, det.score, det.mask, count++); } - - cv::rectangle(img, cv::Point{(int)box.left, (int)box.top}, - cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0}); } - cv::imwrite("output_detection.png", img); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } return 0; } diff --git a/demo/csrc/cpp/pose_tracker.cpp b/demo/csrc/cpp/pose_tracker.cpp deleted file mode 100644 index 077b3d39ea..0000000000 --- a/demo/csrc/cpp/pose_tracker.cpp +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "mmdeploy/pose_tracker.hpp" - -#include - -#include "opencv2/highgui/highgui.hpp" -#include "opencv2/imgcodecs/imgcodecs.hpp" -#include "opencv2/imgproc/imgproc.hpp" -#include "opencv2/videoio/videoio.hpp" - -struct Args { - std::string device; - std::string det_model; - std::string pose_model; - std::string video; - std::string output_dir; -}; - -Args ParseArgs(int argc, char* argv[]); - -using std::vector; -using namespace mmdeploy; - -bool Visualize(cv::Mat frame, const PoseTracker::Result& result, int size, - const std::string& output_dir, int frame_id, bool with_bbox); - -int main(int argc, char* argv[]) { - auto args = ParseArgs(argc, argv); - if (args.device.empty()) { - return 0; - } - - // create pose tracker pipeline - PoseTracker tracker(Model(args.det_model), Model(args.pose_model), Context{Device{args.device}}); - - // set parameters - PoseTracker::Params params; - params->det_min_bbox_size = 100; - params->det_interval = 1; - params->pose_max_num_bboxes = 6; - - // optionally use OKS for keypoints similarity comparison - std::array sigmas{0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, - 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089}; - params->keypoint_sigmas = sigmas.data(); - params->keypoint_sigmas_size = sigmas.size(); - - // create a tracker state for each video - PoseTracker::State state = tracker.CreateState(params); - - cv::VideoCapture video; - if 
(args.video.size() == 1 && std::isdigit(args.video[0])) { - video.open(std::stoi(args.video)); // open by camera index - } else { - video.open(args.video); // open video file - } - if (!video.isOpened()) { - std::cerr << "failed to open video: " << args.video << "\n"; - } - - cv::Mat frame; - int frame_id = 0; - while (true) { - video >> frame; - if (!frame.data) { - break; - } - // apply the pipeline with the tracker state and video frame - auto result = tracker.Apply(state, frame); - // visualize the results - if (!Visualize(frame, result, 1280, args.output_dir, frame_id++, false)) { - break; - } - } - - return 0; -} - -struct Skeleton { - vector> skeleton; - vector palette; - vector link_color; - vector point_color; -}; - -const Skeleton& gCocoSkeleton() { - static const Skeleton inst{ - { - {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, - {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, - {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, - }, - { - {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, - {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, - {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, - {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}, - }, - {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}, - {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}, - }; - return inst; -} - -bool Visualize(cv::Mat frame, const PoseTracker::Result& result, int size, - const std::string& output_dir, int frame_id, bool with_bbox) { - auto& [skeleton, palette, link_color, point_color] = gCocoSkeleton(); - auto scale = (float)size / (float)std::max(frame.cols, frame.rows); - if (scale != 1) { - cv::resize(frame, frame, {}, scale, scale); - } else { - frame = frame.clone(); - } - auto draw_bbox = [&](std::array bbox, const cv::Scalar& color) { - std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; }); - 
cv::rectangle(frame, cv::Point2f(bbox[0], bbox[1]), cv::Point2f(bbox[2], bbox[3]), color); - }; - for (const auto& r : result) { - vector kpts(&r.keypoints[0].x, &r.keypoints[0].x + r.keypoint_count * 2); - vector scores(r.scores, r.scores + r.keypoint_count); - std::for_each(kpts.begin(), kpts.end(), [&](auto& x) { x *= scale; }); - constexpr auto score_thr = .5f; - vector used(kpts.size()); - for (size_t i = 0; i < skeleton.size(); ++i) { - auto [u, v] = skeleton[i]; - if (scores[u] > score_thr && scores[v] > score_thr) { - used[u] = used[v] = 1; - cv::Point2f p_u(kpts[u * 2], kpts[u * 2 + 1]); - cv::Point2f p_v(kpts[v * 2], kpts[v * 2 + 1]); - cv::line(frame, p_u, p_v, palette[link_color[i]], 1, cv::LINE_AA); - } - } - for (size_t i = 0; i < kpts.size(); i += 2) { - if (used[i / 2]) { - cv::Point2f p(kpts[i], kpts[i + 1]); - cv::circle(frame, p, 1, palette[point_color[i / 2]], 2, cv::LINE_AA); - } - } - if (with_bbox) { - draw_bbox((std::array&)r.bbox, cv::Scalar(0, 255, 0)); - } - } - if (output_dir.empty()) { - cv::imshow("pose_tracker", frame); - return cv::waitKey(1) != 'q'; - } - auto join = [](const std::string& a, const std::string& b) { -#if _MSC_VER - return a + "\\" + b; -#else - return a + "/" + b; -#endif - }; - cv::imwrite(join(output_dir, std::to_string(frame_id) + ".jpg"), frame, - {cv::IMWRITE_JPEG_QUALITY, 90}); - return true; -} - -Args ParseArgs(int argc, char* argv[]) { - if (argc < 5 || argc > 6) { - std::cout << R"(Usage: pose_tracker device_name det_model pose_model video [output] - device_name device name for execution, e.g. 
"cpu", "cuda" - det_model object detection model path - pose_model pose estimation model path - video video path or camera index - output output directory, will cv::imshow if omitted -)"; - return {}; - } - Args args; - args.device = argv[1]; - args.det_model = argv[2]; - args.pose_model = argv[3]; - args.video = argv[4]; - if (argc == 6) { - args.output_dir = argv[5]; - } - return args; -} diff --git a/demo/csrc/cpp/pose_tracker.cxx b/demo/csrc/cpp/pose_tracker.cxx new file mode 100644 index 0000000000..c57cc3e4a4 --- /dev/null +++ b/demo/csrc/cpp/pose_tracker.cxx @@ -0,0 +1,68 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "mmdeploy/pose_tracker.hpp" + +#include "utils/argparse.h" +#include "utils/mediaio.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(det_model, "Object detection model path"); +DEFINE_ARG_string(pose_model, "Pose estimation model path"); +DEFINE_ARG_string(input, "Input video path or camera index"); + +DEFINE_string(device, "cpu", "Device name, e.g. 
\"cpu\", \"cuda\""); +DEFINE_string(output, "", "Output video path or format string"); + +DEFINE_int32(output_size, 0, "Long-edge of output frames"); +DEFINE_int32(show, 1, "Delay passed to `cv::waitKey` when using `cv::imshow`; -1: disable"); + +DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")"); +DEFINE_string(background, "default", + R"(Output background, "default": original image, "black": black background)"); + +#include "pose_tracker_params.h" + +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } + + // create pose tracker pipeline + mmdeploy::PoseTracker tracker(mmdeploy::Model(ARGS_det_model), mmdeploy::Model(ARGS_pose_model), + mmdeploy::Device{FLAGS_device}); + + mmdeploy::PoseTracker::Params params; + // initialize tracker params with program arguments + InitTrackerParams(params); + + // create a tracker state for each video + mmdeploy::PoseTracker::State state = tracker.CreateState(params); + + utils::mediaio::Input input(ARGS_input); + utils::mediaio::Output output(FLAGS_output, FLAGS_show); + + utils::Visualize v(FLAGS_output_size); + v.set_background(FLAGS_background); + v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton)); + + for (const cv::Mat& frame : input) { + // apply the pipeline with the tracker state and video frame; the result is an array-like class + // holding references to `mmdeploy_pose_tracker_target_t`, will be released automatically on + // destruction + mmdeploy::PoseTracker::Result result = tracker.Apply(state, frame); + + // visualize results + auto sess = v.get_session(frame); + for (const mmdeploy_pose_tracker_target_t& target : result) { + sess.add_pose(target.keypoints, target.scores, target.keypoint_count, FLAGS_pose_kpt_thr); + } + + // write to output stream + if (!output.write(sess.get())) { + // user requested exit by pressing 'q' + break; + } + } + + return 0; +} diff --git a/demo/csrc/cpp/pose_tracker_params.h 
b/demo/csrc/cpp/pose_tracker_params.h new file mode 100644 index 0000000000..2cda301869 --- /dev/null +++ b/demo/csrc/cpp/pose_tracker_params.h @@ -0,0 +1,39 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +DEFINE_int32(det_interval, 1, "Detection interval"); +DEFINE_int32(det_label, 0, "Detection label use for pose estimation"); +DEFINE_double(det_thr, 0.5, "Detection score threshold"); +DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size"); +DEFINE_double(det_nms_thr, .7, + "NMS IOU threshold for merging detected bboxes and bboxes from tracked targets"); + +DEFINE_int32(pose_max_num_bboxes, -1, "Max number of bboxes used for pose estimation per frame"); +DEFINE_double(pose_kpt_thr, .5, "Threshold for visible key-points"); +DEFINE_int32(pose_min_keypoints, -1, + "Min number of key-points for valid poses, -1 indicates ceil(n_kpts/2)"); +DEFINE_double(pose_bbox_scale, 1.25, "Scale for expanding key-points to bbox"); +DEFINE_double( + pose_min_bbox_size, -1, + "Min pose bbox size, tracks with bbox size smaller than the threshold will be dropped"); +DEFINE_double(pose_nms_thr, 0.5, + "NMS OKS/IOU threshold for suppressing overlapped poses, useful when multiple pose " + "estimations collapse to the same target"); + +DEFINE_double(track_iou_thr, 0.4, "IOU threshold for associating missing tracks"); +DEFINE_int32(track_max_missing, 10, + "Max number of missing frames before a missing tracks is removed"); + +void InitTrackerParams(mmdeploy::PoseTracker::Params& params) { + params->det_interval = FLAGS_det_interval; + params->det_label = FLAGS_det_label; + params->det_thr = FLAGS_det_thr; + params->det_min_bbox_size = FLAGS_det_min_bbox_size; + params->pose_max_num_bboxes = FLAGS_pose_max_num_bboxes; + params->pose_kpt_thr = FLAGS_pose_kpt_thr; + params->pose_min_keypoints = FLAGS_pose_min_keypoints; + params->pose_bbox_scale = FLAGS_pose_bbox_scale; + params->pose_min_bbox_size = FLAGS_pose_min_bbox_size; + params->pose_nms_thr = FLAGS_pose_nms_thr; + 
params->track_iou_thr = FLAGS_track_iou_thr; + params->track_max_missing = FLAGS_track_max_missing; +} diff --git a/demo/csrc/cpp/restorer.cxx b/demo/csrc/cpp/restorer.cxx index 7c3eefd82c..378294fed4 100644 --- a/demo/csrc/cpp/restorer.cxx +++ b/demo/csrc/cpp/restorer.cxx @@ -2,33 +2,39 @@ #include "mmdeploy/restorer.hpp" -#include -#include -#include +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "opencv2/imgproc/imgproc.hpp" +#include "utils/argparse.h" + +DEFINE_ARG_string(model, "Super-resolution model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "restorer_output.jpg", "Output image path"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n image_restorer device_name model_path image_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; } - using namespace mmdeploy; + // construct a restorer instance + mmdeploy::Restorer restorer{mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}}; - Restorer restorer{Model{model_path}, Device{device_name}}; + // apply restorer to the image + mmdeploy::Restorer::Result result = restorer.Apply(img); - auto result = restorer.Apply(img); + // convert to BGR + cv::Mat upsampled(result->height, result->width, CV_8UC3, result->data); + cv::cvtColor(upsampled, upsampled, cv::COLOR_RGB2BGR); - cv::Mat sr_img(result->height, result->width, CV_8UC3, result->data); - cv::cvtColor(sr_img, sr_img, cv::COLOR_RGB2BGR); - cv::imwrite("output_restorer.bmp", sr_img); + if (!FLAGS_output.empty()) { + 
cv::imwrite(FLAGS_output, upsampled); + } return 0; } diff --git a/demo/csrc/cpp/rotated_detector.cxx b/demo/csrc/cpp/rotated_detector.cxx index d590273d32..3ddb5d4c1c 100644 --- a/demo/csrc/cpp/rotated_detector.cxx +++ b/demo/csrc/cpp/rotated_detector.cxx @@ -1,51 +1,46 @@ #include "mmdeploy/rotated_detector.hpp" -#include -#include -#include +#include "utils/argparse.h" +#include "utils/visualize.h" + +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); +DEFINE_string(output, "rotated_detector_output.jpg", "Output image path"); + +DEFINE_double(det_thr, 0.1, "Detection score threshold"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n oriented_object_detection device_name model_path image_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; } - mmdeploy::Model model(model_path); - mmdeploy::RotatedDetector detector(model, mmdeploy::Device{device_name, 0}); + // construct a detector instance + mmdeploy::RotatedDetector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}); - auto dets = detector.Apply(img); + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_rotated_detection_t`, will be released automatically on destruction + mmdeploy::RotatedDetector::Result dets = detector.Apply(img); - for (const auto& det : dets) { - if (det.score < 0.1) { - continue; + // visualize results + utils::Visualize v; + auto sess = v.get_session(img); + for (const mmdeploy_rotated_detection_t& det : 
dets) { + if (det.score > FLAGS_det_thr) { + sess.add_rotated_det(det.rbbox, det.label_id, det.score); } - auto& rbbox = det.rbbox; - float xc = rbbox[0]; - float yc = rbbox[1]; - float w = rbbox[2]; - float h = rbbox[3]; - float ag = rbbox[4]; - float wx = w / 2 * std::cos(ag); - float wy = w / 2 * std::sin(ag); - float hx = -h / 2 * std::sin(ag); - float hy = h / 2 * std::cos(ag); - cv::Point p1 = {int(xc - wx - hx), int(yc - wy - hy)}; - cv::Point p2 = {int(xc + wx - hx), int(yc + wy - hy)}; - cv::Point p3 = {int(xc + wx + hx), int(yc + wy + hy)}; - cv::Point p4 = {int(xc - wx + hx), int(yc - wy + hy)}; - cv::drawContours(img, std::vector>{{p1, p2, p3, p4}}, -1, {0, 255, 0}, - 2); } - cv::imwrite("output_rotated_detection.png", img); + + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } return 0; } diff --git a/demo/csrc/cpp/segmentor.cxx b/demo/csrc/cpp/segmentor.cxx index be5100f6a0..55ba41db67 100644 --- a/demo/csrc/cpp/segmentor.cxx +++ b/demo/csrc/cpp/segmentor.cxx @@ -2,75 +2,46 @@ #include "mmdeploy/segmentor.hpp" -#include -#include -#include -#include -#include #include #include -using namespace std; +#include "utils/argparse.h" +#include "utils/mediaio.h" +#include "utils/visualize.h" -vector gen_palette(int num_classes) { - std::mt19937 gen; - std::uniform_int_distribution uniform_dist(0, 255); - - vector palette; - palette.reserve(num_classes); - for (auto i = 0; i < num_classes; ++i) { - palette.emplace_back(uniform_dist(gen), uniform_dist(gen), uniform_dist(gen)); - } - return palette; -} +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g.
"cpu", "cuda")"); +DEFINE_string(output, "segmentor_output.jpg", "Output image path"); +DEFINE_string(palette, "cityscapes", + R"(Path to palette data or name of predefined palettes: "cityscapes")"); int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n image_segmentation device_name model_path image_path\n"); - return 1; - } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto image_path = argv[3]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - using namespace mmdeploy; + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } - Segmentor segmentor{Model{model_path}, Device{device_name}}; + mmdeploy::Segmentor segmentor{mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}}; - auto result = segmentor.Apply(img); + // apply the segmentor, the result is an array-like class holding a reference to + // `mmdeploy_segmentation_t`, will be released automatically on destruction + mmdeploy::Segmentor::Result seg = segmentor.Apply(img); - auto palette = gen_palette(result->classes + 1); + // visualize + utils::Visualize v; + v.set_palette(utils::Palette::get(FLAGS_palette)); + auto sess = v.get_session(img); + sess.add_mask(seg->height, seg->width, seg->classes, seg->mask, seg->score); - cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3); - int pos = 0; - int total = color_mask.rows * color_mask.cols; - std::vector idxs(result->classes); - for (auto iter = color_mask.begin(); iter != color_mask.end(); ++iter) { - // output mask - if (result->mask) { - *iter = palette[result->mask[pos++]]; - } - // output score - if (result->score) { - std::iota(idxs.begin(), idxs.end(), 0); - auto k = - std::max_element(idxs.begin(), idxs.end(), - [&](int i, int j) { - return
result->score[pos + i * total] < result->score[pos + j * total]; - }) - - idxs.begin(); - *iter = palette[k]; - pos += 1; - } + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); } - img = img * 0.5 + color_mask * 0.5; - cv::imwrite("output_segmentation.png", img); - return 0; } diff --git a/demo/csrc/cpp/text_ocr.cxx b/demo/csrc/cpp/text_ocr.cxx index 853d681071..6c8fdb055b 100644 --- a/demo/csrc/cpp/text_ocr.cxx +++ b/demo/csrc/cpp/text_ocr.cxx @@ -1,46 +1,57 @@ -#include -#include #include #include "mmdeploy/text_detector.hpp" #include "mmdeploy/text_recognizer.hpp" +#include "utils/argparse.h" +#include "utils/mediaio.h" +#include "utils/visualize.h" -int main(int argc, char* argv[]) { - if (argc != 5) { - fprintf(stderr, "usage:\n ocr device_name det_model_path reg_model_path image_path\n"); - return 1; - } - const auto device_name = argv[1]; - auto det_model_path = argv[2]; - auto reg_model_path = argv[3]; - auto image_path = argv[4]; - cv::Mat img = cv::imread(image_path); - if (!img.data) { - fprintf(stderr, "failed to load image: %s\n", image_path); - return 1; - } +DEFINE_ARG_string(det_model, "Text detection model path"); +DEFINE_ARG_string(reg_model, "Text recognition model path"); +DEFINE_ARG_string(image, "Input image path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. 
"cpu", "cuda")"); +DEFINE_string(output, "text_ocr_output.jpg", "Output image path"); - using namespace mmdeploy; +using mmdeploy::TextDetector; +using mmdeploy::TextRecognizer; - TextDetector detector{Model(det_model_path), Device(device_name)}; - TextRecognizer recognizer{Model(reg_model_path), Device(device_name)}; +int main(int argc, char* argv[]) { + if (!utils::ParseArguments(argc, argv)) { + return -1; + } - auto bboxes = detector.Apply(img); - auto texts = recognizer.Apply(img, {bboxes.begin(), bboxes.size()}); + cv::Mat img = cv::imread(ARGS_image); + if (img.empty()) { + fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str()); + return -1; + } - for (int i = 0; i < bboxes.size(); ++i) { - fprintf(stdout, "box[%d]: %s\n", i, texts[i].text); - std::vector poly_points; - for (const auto& pt : bboxes[i].bbox) { - fprintf(stdout, "x: %.2f, y: %.2f, ", pt.x, pt.y); - poly_points.emplace_back((int)pt.x, (int)pt.y); - } - fprintf(stdout, "\n"); - cv::polylines(img, poly_points, true, cv::Scalar{0, 255, 0}); + mmdeploy::Device device(FLAGS_device); + TextDetector detector{mmdeploy::Model(ARGS_det_model), device}; + TextRecognizer recognizer{mmdeploy::Model(ARGS_reg_model), device}; + + // apply the detector, the result is an array-like class holding references to + // `mmdeploy_text_detection_t`, will be released automatically on destruction + TextDetector::Result bboxes = detector.Apply(img); + + // apply recognizer, if no bboxes are provided, full image will be used; the result is an + // array-like class holding references to `mmdeploy_text_recognition_t`, will be released + // automatically on destruction + TextRecognizer::Result texts = recognizer.Apply(img, {bboxes.begin(), bboxes.size()}); + + // visualize results + utils::Visualize v; + auto sess = v.get_session(img); + for (size_t i = 0; i < bboxes.size(); ++i) { + mmdeploy_text_detection_t& bbox = bboxes[i]; + mmdeploy_text_recognition_t& text = texts[i]; + sess.add_text_det(bbox.bbox, 
bbox.score, text.text, text.length, i); } - cv::imwrite("output_ocr.png", img); + if (!FLAGS_output.empty()) { + cv::imwrite(FLAGS_output, sess.get()); + } return 0; } diff --git a/demo/csrc/cpp/utils/argparse.h b/demo/csrc/cpp/utils/argparse.h new file mode 100644 index 0000000000..5c94c8afad --- /dev/null +++ b/demo/csrc/cpp/utils/argparse.h @@ -0,0 +1,272 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_ARGPARSE_H +#define MMDEPLOY_ARGPARSE_H + +#include +#include +#include +#include +#include +#include +#include + +#define DEFINE_int32(name, init, msg) _MMDEPLOY_DEFINE_FLAG(int32_t, name, init, msg) +#define DEFINE_double(name, init, msg) _MMDEPLOY_DEFINE_FLAG(double, name, init, msg) +#define DEFINE_string(name, init, msg) _MMDEPLOY_DEFINE_FLAG(std::string, name, init, msg) + +#define DEFINE_ARG_int32(name, msg) _MMDEPLOY_DEFINE_ARG(int32_t, name, msg) +#define DEFINE_ARG_double(name, msg) _MMDEPLOY_DEFINE_ARG(double, name, msg) +#define DEFINE_ARG_string(name, msg) _MMDEPLOY_DEFINE_ARG(std::string, name, msg) + +namespace utils { + +class ArgParse { + public: + template + static T Register(const std::string& type, const std::string& name, T init, + const std::string& msg, void* ptr) { + instance()._Register(type, name, msg, true, ptr); + return init; + } + + template + static T Register(const std::string& type, const std::string& name, const std::string& msg, + void* ptr) { + instance()._Register(type, name, msg, false, ptr); + return {}; + } + + static bool ParseArguments(int argc, char* argv[]) { + if (!instance()._Parse(argc, argv)) { + ShowUsageWithFlags(argv[0]); + return false; + } + return true; + } + + static void ShowUsageWithFlags(const char* argv0) { instance()._ShowUsageWithFlags(argv0); } + + private: + static ArgParse& instance() { + static ArgParse inst; + return inst; + } + + struct Info { + std::string name; + std::string type; + std::string msg; + bool is_flag; + void* ptr; + }; + + void _Register(std::string type, 
const std::string& name, const std::string& msg, bool is_flag, + void* ptr) { + if (type == "std::string") { + type = "string"; + } else if (type == "int32_t") { + type = "int32"; + } + infos_.push_back({name, type, msg, is_flag, ptr}); + } + + bool _Parse(int argc, char* argv[]) { + int arg_idx{-1}; + std::vector args(infos_.size()); + std::vector used(infos_.size()); + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + return false; + } + if (argv[i][0] == '-' && argv[i][1] == '-') { + // parse flag key-value pair (--x=y or --x y) + int eq{-1}; + for (int k = 2; argv[i][k]; ++k) { + if (argv[i][k] == '=') { + eq = k; + break; + } + } + std::string key; + std::string val; + if (eq >= 0) { + key = std::string(argv[i] + 2, argv[i] + eq); + val = std::string(argv[i] + eq + 1); + } else { + key = std::string(argv[i] + 2); + if (i < argc - 1) { + val = argv[++i]; + } + } + bool found{}; + for (int j = 0; j < infos_.size(); ++j) { + auto& flag = infos_[j]; + if (key == flag.name) { + args[j] = val; + found = used[j] = 1; + break; + } + } + if (!found) { + std::cout << "error: unknown option: " << key << std::endl; + return false; + } + } else { + for (arg_idx++; arg_idx < infos_.size(); ++arg_idx) { + if (!infos_[arg_idx].is_flag) { + args[arg_idx] = argv[i]; + used[arg_idx] = 1; + break; + } + } + if (arg_idx == infos_.size()) { + std::cout << "error: unknown argument: " << argv[i] << std::endl; + return false; + } + } + } + std::vector missing; + for (arg_idx++; arg_idx < infos_.size(); ++arg_idx) { + if (!infos_[arg_idx].is_flag) { + missing.push_back(infos_[arg_idx].name); + } + } + if (!missing.empty()) { + std::cout << "error: the following arguments are required:"; + for (int i = 0; i < missing.size(); ++i) { + std::cout << " " << missing[i]; + if (i != missing.size() - 1) { + std::cout << ","; + } + } + std::cout << "\n"; + return false; + } + + for (int i = 0; i < infos_.size(); ++i) { + if (used[i]) { + try 
{ + parse_str(infos_[i], args[i]); + } catch (...) { + std::cout << "error: failed to parse " << infos_[i].name << ": " << args[i] << std::endl; + return false; + } + } + } + + return true; + } + + static void parse_str(Info& info, const std::string& str) { + if (info.type == "int32") { + *static_cast(info.ptr) = std::stoi(str); + } else if (info.type == "double") { + *static_cast(info.ptr) = std::stod(str); + } else if (info.type == "string") { + *static_cast(info.ptr) = str; + } else { + // pass + } + } + + static std::string get_default_str(const Info& info) { + if (info.type == "int32") { + return std::to_string(*static_cast(info.ptr)); + } else if (info.type == "double") { + std::ostringstream os; + os << std::setprecision(3) << *static_cast(info.ptr); + return os.str(); + } else if (info.type == "string") { + return "\"" + *(static_cast(info.ptr)) + "\""; + } else { + return ""; + } + } + + void _ShowUsageWithFlags(const char* argv0) const { + ShowUsage(argv0); + static constexpr const auto kLineLength = 80; + std::cout << std::endl; + int max_name_length = 0; + for (const auto& info : infos_) { + max_name_length = std::max(max_name_length, (int)info.name.length()); + } + max_name_length += 4; + auto name_col_size = max_name_length + 1; + auto msg_col_size = kLineLength - name_col_size; + std::cout << "required arguments:\n"; + ShowFlags(name_col_size, msg_col_size, false); + std::cout << std::endl; + std::cout << "optional arguments:\n"; + ShowFlags(name_col_size, msg_col_size, true); + } + + void ShowFlags(int name_col_size, int msg_col_size, bool is_flag) const { + for (const auto& info : infos_) { + if (info.is_flag != is_flag) { + continue; + } + std::string name = " "; + if (info.is_flag) { + name.append("--"); + } + name.append(info.name); + while (name.length() < name_col_size) { + name.append(" "); + } + std::cout << name; + std::string msg = info.msg; + while (msg.length() > msg_col_size) { // insert line-breaks when msg is too long + auto pos = 
msg.rend() - std::find(std::make_reverse_iterator(msg.begin() + msg_col_size), + msg.rend(), ' '); + std::cout << msg.substr(0, pos - 1) << std::endl; + std::cout << std::string(name_col_size, ' '); + msg = msg.substr(pos); + } + std::cout << msg; + std::string type; + type.append("[").append(info.type); + if (info.is_flag) { + type.append(" = ").append(get_default_str(info)); + } + type.append("]"); + if (msg.length() + type.length() + 1 > msg_col_size) { + std::cout << std::endl << std::string(name_col_size, ' ') << type; + } else { + std::cout << " " << type; + } + std::cout << std::endl; + } + } + + void ShowUsage(const char* argv0) const { + for (auto p = argv0; *p; ++p) { + if (*p == '/' || *p == '\'') { + argv0 = p + 1; + } + } + std::cout << "Usage: " << argv0 << " [options]"; + for (const auto& info : infos_) { + if (!info.is_flag) { + std::cout << " " << info.name; + } + } + std::cout << std::endl; + } + + private: + std::vector infos_; +}; + +inline bool ParseArguments(int argc, char* argv[]) { return ArgParse::ParseArguments(argc, argv); } + +} // namespace utils + +#define _MMDEPLOY_DEFINE_FLAG(type, name, init, msg) \ + type FLAGS_##name = ::utils::ArgParse::Register(#type, #name, type(init), msg, &FLAGS_##name) + +#define _MMDEPLOY_DEFINE_ARG(type, name, msg) \ + type ARGS_##name = ::utils::ArgParse::Register(#type, #name, msg, &ARGS_##name) + +#endif // MMDEPLOY_ARGPARSE_H diff --git a/demo/csrc/cpp/utils/mediaio.h b/demo/csrc/cpp/utils/mediaio.h new file mode 100644 index 0000000000..3426dfb3e0 --- /dev/null +++ b/demo/csrc/cpp/utils/mediaio.h @@ -0,0 +1,388 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#ifndef MMDEPLOY_MEDIAIO_H +#define MMDEPLOY_MEDIAIO_H + +#include +#include + +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgcodecs/imgcodecs.hpp" +#include "opencv2/videoio/videoio.hpp" + +namespace utils { +namespace mediaio { + +enum class MediaType { kUnknown, kImage, kVideo, kImageList, kWebcam, kFmtStr, kDisable }; + +namespace detail { + +static std::string get_extension(const std::string& path) { + std::string ext; + for (auto i = (int)path.size() - 1; i >= 0; --i) { + if (path[i] == '.') { + ext.push_back(path[i]); + for (++i; i < path.size(); ++i) { + ext.push_back((char)std::tolower((unsigned char)path[i])); + } + return ext; + } + } + return {}; +} + +int ext2fourcc(const std::string& ext) { + auto get_fourcc = [](const char* s) { return cv::VideoWriter::fourcc(s[0], s[1], s[2], s[3]); }; + static std::map ext2fourcc{ + {".mp4", get_fourcc("mp4v")}, + {".avi", get_fourcc("DIVX")}, + {".mkv", get_fourcc("X264")}, + {".wmv", get_fourcc("WMV3")}, + }; + auto it = ext2fourcc.find(ext); + if (it != ext2fourcc.end()) { + return it->second; + } + return get_fourcc("DIVX"); +} + +static bool is_video(const std::string& ext) { + static const std::set es{".mp4", ".avi", ".mkv", ".webm", ".mov", ".mpg", ".wmv"}; + return es.count(ext); +} + +static bool is_list(const std::string& ext) { + static const std::set es{".txt"}; + return es.count(ext); +} + +static bool is_image(const std::string& ext) { + static const std::set es{".jpg", ".jpeg", ".png", ".tif", ".tiff", + ".bmp", ".ppm", ".pgm", ".webp"}; + return es.count(ext); +} + +static bool is_fmtstr(const std::string& str) { + for (const auto& c : str) { + if (c == '%') { + return true; + } + } + return false; +} + +} // namespace detail + +class Input; + +class InputIterator { + public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = cv::Mat&; + using value_type = reference; + using pointer = void; + + public: + 
InputIterator() = default; + explicit InputIterator(Input& input) : input_(&input) { next(); } + InputIterator& operator++() { + next(); + return *this; + } + reference operator*() { return frame_; } + friend bool operator==(const InputIterator& a, const InputIterator& b) { + return &a == &b || a.is_end() == b.is_end(); + } + friend bool operator!=(const InputIterator& a, const InputIterator& b) { return !(a == b); } + + private: + void next(); + bool is_end() const noexcept { return frame_.data != nullptr; } + + private: + cv::Mat frame_; + Input* input_{}; +}; + +class BatchInputIterator { + public: + using iterator_category = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = std::vector&; + using value_type = reference; + using pointer = void; + + public: + BatchInputIterator() = default; + BatchInputIterator(InputIterator iter, InputIterator end, size_t batch_size) + : iter_(std::move(iter)), end_(std::move(end)), batch_size_(batch_size) { + next(); + } + + BatchInputIterator& operator++() { + next(); + return *this; + } + + reference operator*() { return data_; } + + friend bool operator==(const BatchInputIterator& a, const BatchInputIterator& b) { + return &a == &b || a.is_end() == b.is_end(); + } + + friend bool operator!=(const BatchInputIterator& a, const BatchInputIterator& b) { + return !(a == b); + } + + private: + void next() { + data_.clear(); + for (size_t i = 0; i < batch_size_ && iter_ != end_; ++i, ++iter_) { + data_.push_back(*iter_); + } + } + + bool is_end() const { return data_.empty(); } + + private: + InputIterator iter_; + InputIterator end_; + size_t batch_size_{1}; + std::vector data_; +}; + +class Input { + public: + explicit Input(const std::string& path, MediaType type = MediaType::kUnknown) + : path_(path), type_(type) { + if (type_ == MediaType::kUnknown) { + auto ext = detail::get_extension(path); + if (detail::is_image(ext)) { + type_ = MediaType::kImage; + } else if (detail::is_video(ext)) { + 
type_ = MediaType::kVideo; + } else if (path.size() == 1 && std::isdigit((unsigned char)path[0])) { + type_ = MediaType::kWebcam; + } else if (detail::is_list(ext) || try_image_list(path)) { + type_ = MediaType::kImageList; + } else if (try_image(path)) { + type_ = MediaType::kImage; + } else if (try_video(path)) { + type_ = MediaType::kVideo; + } else { + std::cout << "unknown file type: " << path << "\n"; + } + } + if (type_ != MediaType::kUnknown) { + if (type_ == MediaType::kVideo) { + cap_.open(path_); + if (!cap_.isOpened()) { + std::cerr << "failed to open video file: " << path_ << "\n"; + } + } else if (type_ == MediaType::kWebcam) { + cap_.open(std::stoi(path_)); + if (!cap_.isOpened()) { + std::cerr << "failed to open camera index: " << path_ << "\n"; + } + type_ = MediaType::kVideo; + } else if (type_ == MediaType::kImage) { + items_ = {path_}; + type_ = MediaType::kImageList; + } else if (type_ == MediaType::kImageList) { + if (items_.empty()) { + items_ = load_image_list(path); + } + } + } + } + InputIterator begin() { return InputIterator(*this); } + InputIterator end() { return {}; } // NOLINT + + cv::Mat read() { + cv::Mat img; + if (type_ == MediaType::kVideo) { + cap_ >> img; + } else if (type_ == MediaType::kImageList) { + while (!img.data && index_ < items_.size()) { + auto path = items_[index_++]; + img = cv::imread(path); + if (!img.data) { + std::cerr << "failed to load image: " << path << "\n"; + } + } + } + return img; + } + + class Batch { + public: + Batch(Input& input, size_t batch_size) : input_(&input), batch_size_(batch_size) {} + BatchInputIterator begin() { return {input_->begin(), input_->end(), batch_size_}; } + BatchInputIterator end() { return {}; } // NOLINT + + private: + Input* input_{}; + size_t batch_size_{1}; + }; + + Batch batch(size_t batch_size) { return {*this, batch_size}; } + + private: + static bool try_image(const std::string& path) { return cv::imread(path).data; } + + static bool try_video(const std::string& 
path) { return cv::VideoCapture(path).isOpened(); } + + static std::vector load_image_list(const std::string& path, size_t max_bytes = 0) { + std::ifstream ifs(path); + ifs.seekg(0, std::ifstream::end); + auto size = ifs.tellg(); + ifs.seekg(0, std::ifstream::beg); + if (max_bytes && size > max_bytes) { + return {}; + } + auto strip = [](std::string& s) { + while (!s.empty() && std::isspace((unsigned char)s.back())) { + s.pop_back(); + } + }; + std::vector ret; + std::string line; + while (std::getline(ifs, line)) { + strip(line); + if (!line.empty()) { + ret.push_back(std::move(line)); + } + } + return ret; + } + + bool try_image_list(const std::string& path) { + auto items = load_image_list(path, 1 << 20); + size_t count = 0; + for (const auto& item : items) { + if (detail::is_image(detail::get_extension(item)) && ++count > items.size() / 10) { + items_ = std::move(items); + return true; + } + } + return false; + } + + private: + MediaType type_{MediaType::kUnknown}; + std::string path_; + std::vector items_; + cv::VideoCapture cap_; + size_t index_{}; +}; + +inline void InputIterator::next() { + assert(input_); + frame_ = input_->read(); +} + +class Output; + +class OutputIterator { + public: + using iterator_category = std::output_iterator_tag; + using difference_type = std::ptrdiff_t; + using reference = void; + using value_type = void; + using pointer = void; + + public: + explicit OutputIterator(Output& output) : output_(&output) {} + + OutputIterator& operator=(const cv::Mat& frame); + + OutputIterator& operator*() { return *this; } + OutputIterator& operator++() { return *this; } + OutputIterator& operator++(int) { return *this; } // NOLINT + + private: + Output* output_{}; +}; + +class Output { + public: + explicit Output(const std::string& path, int show, MediaType type = MediaType::kUnknown) + : path_(path), type_(type), show_(show) { + ext_ = detail::get_extension(path); + if (type_ == MediaType::kUnknown) { + if (path_.empty()) { + type_ = 
MediaType::kDisable; + } else if (detail::is_image(ext_)) { + if (detail::is_fmtstr(path)) { + type_ = MediaType::kFmtStr; + } else { + type_ = MediaType::kImage; + } + } else if (detail::is_video(ext_)) { + type_ = MediaType::kVideo; + } else { + std::cout << "unknown file type: " << path << "\n"; + } + } + } + + bool write(const cv::Mat& frame) { + bool exit = false; + switch (type_) { + case MediaType::kDisable: + break; + case MediaType::kImage: + cv::imwrite(path_, frame); + break; + case MediaType::kFmtStr: { + char buf[256]; + snprintf(buf, sizeof(buf), path_.c_str(), frame_id_); + cv::imwrite(buf, frame); + break; + } + case MediaType::kVideo: + write_video(frame); + break; + default: + std::cout << "unsupported output media type\n"; + assert(0); + } + if (show_ >= 0) { + cv::imshow("", frame); + exit = cv::waitKey(show_) == 'q'; + } + ++frame_id_; + return !exit; + } + + OutputIterator inserter() { return OutputIterator{*this}; } + + private: + void write_video(const cv::Mat& frame) { + if (!video_.isOpened()) { + open_video(frame.size()); + } + video_ << frame; + } + + void open_video(const cv::Size& size) { video_.open(path_, detail::ext2fourcc(ext_), 30, size); } + + private: + std::string path_; + std::string ext_; + MediaType type_{MediaType::kUnknown}; + int show_{1}; + size_t frame_id_{0}; + cv::VideoWriter video_; +}; + +OutputIterator& OutputIterator::operator=(const cv::Mat& frame) { + assert(output_); + output_->write(frame); + return *this; +} + +} // namespace mediaio +} // namespace utils +#endif // MMDEPLOY_MEDIAIO_H diff --git a/demo/csrc/cpp/utils/palette.h b/demo/csrc/cpp/utils/palette.h new file mode 100644 index 0000000000..03d76a4360 --- /dev/null +++ b/demo/csrc/cpp/utils/palette.h @@ -0,0 +1,94 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#ifndef MMDEPLOY_PALETTE_H +#define MMDEPLOY_PALETTE_H + +#include +#include +#include +#include +#include +#include + +namespace utils { + +struct Palette { + std::vector data; + static Palette get(const std::string& path); + static Palette get(int n); +}; + +inline Palette Palette::get(const std::string& path) { + if (path == "coco") { + Palette p{{ + {220, 20, 60}, {119, 11, 32}, {0, 0, 142}, {0, 0, 230}, {106, 0, 228}, + {0, 60, 100}, {0, 80, 100}, {0, 0, 70}, {0, 0, 192}, {250, 170, 30}, + {100, 170, 30}, {220, 220, 0}, {175, 116, 175}, {250, 0, 30}, {165, 42, 42}, + {255, 77, 255}, {0, 226, 252}, {182, 182, 255}, {0, 82, 0}, {120, 166, 157}, + {110, 76, 0}, {174, 57, 255}, {199, 100, 0}, {72, 0, 118}, {255, 179, 240}, + {0, 125, 92}, {209, 0, 151}, {188, 208, 182}, {0, 220, 176}, {255, 99, 164}, + {92, 0, 73}, {133, 129, 255}, {78, 180, 255}, {0, 228, 0}, {174, 255, 243}, + {45, 89, 255}, {134, 134, 103}, {145, 148, 174}, {255, 208, 186}, {197, 226, 255}, + {171, 134, 1}, {109, 63, 54}, {207, 138, 255}, {151, 0, 95}, {9, 80, 61}, + {84, 105, 51}, {74, 65, 105}, {166, 196, 102}, {208, 195, 210}, {255, 109, 65}, + {0, 143, 149}, {179, 0, 194}, {209, 99, 106}, {5, 121, 0}, {227, 255, 205}, + {147, 186, 208}, {153, 69, 1}, {3, 95, 161}, {163, 255, 0}, {119, 0, 170}, + {0, 182, 199}, {0, 165, 120}, {183, 130, 88}, {95, 32, 0}, {130, 114, 135}, + {110, 129, 133}, {166, 74, 118}, {219, 142, 185}, {79, 210, 114}, {178, 90, 62}, + {65, 70, 15}, {127, 167, 115}, {59, 105, 106}, {142, 108, 45}, {196, 172, 0}, + {95, 54, 80}, {128, 76, 255}, {201, 57, 1}, {246, 0, 122}, {191, 162, 208}, + }}; + for (auto& x : p.data) { + std::swap(x[0], x[2]); + } + return p; + } else if (path == "cityscapes") { + Palette p{{ + {128, 64, 128}, {244, 35, 232}, {70, 70, 70}, {102, 102, 156}, {190, 153, 153}, + {153, 153, 153}, {250, 170, 30}, {220, 220, 0}, {107, 142, 35}, {152, 251, 152}, + {70, 130, 180}, {220, 20, 60}, {255, 0, 0}, {0, 0, 142}, {0, 0, 70}, + {0, 60, 100}, {0, 80, 
100}, {0, 0, 230}, {119, 11, 32}, + }}; + for (auto& x : p.data) { + std::swap(x[0], x[2]); + } + return p; + } + std::ifstream ifs(path); + if (!ifs.is_open()) { + std::cout << "error: failed to open palette data file: " << path << "\n"; + std::abort(); + } + Palette p; + int n = 0; + ifs >> n; + for (int i = 0; i < n; ++i) { + cv::Vec3b x{}; + ifs >> x[0] >> x[1] >> x[2]; + p.data.push_back(x); + } + return p; +} + +inline Palette Palette::get(int n) { + std::vector samples(n * 100); + std::vector indices(samples.size()); + std::iota(indices.begin(), indices.end(), 0); + std::mt19937 gen; // NOLINT + std::uniform_int_distribution uniform_dist(0, 255); + for (auto& x : samples) { + x = {(float)uniform_dist(gen), (float)uniform_dist(gen), (float)uniform_dist(gen)}; + } + std::vector centers; + cv::kmeans(samples, n, indices, cv::TermCriteria(cv::TermCriteria::Type::COUNT, 10, 0), 1, + cv::KMEANS_PP_CENTERS, centers); + Palette p; + for (const auto& c : centers) { + p.data.emplace_back((uchar)c.x, (uchar)c.y, (uchar)c.z); + } + return p; +} + +} // namespace utils + +#endif // MMDEPLOY_PALETTE_H diff --git a/demo/csrc/cpp/utils/skeleton.h b/demo/csrc/cpp/utils/skeleton.h new file mode 100644 index 0000000000..59b0df415a --- /dev/null +++ b/demo/csrc/cpp/utils/skeleton.h @@ -0,0 +1,89 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#ifndef MMDEPLOY_SKELETON_H +#define MMDEPLOY_SKELETON_H + +#include +#include +#include +#include +#include + +namespace utils { + +struct Skeleton { + std::vector> links; + std::vector palette; + std::vector link_colors; + std::vector point_colors; + static Skeleton get(const std::string& path); +}; + +const Skeleton& gCocoSkeleton() { + static const Skeleton inst{ + { + {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, + {5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1}, + {0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6}, + }, + { + {255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255}, + {153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255}, + {255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102}, + {51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255}, + }, + {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16}, + {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0}, + }; + return inst; +} + +// n_links +// u0, v0, u1, v1, ..., un-1, vn-1 +// n_palette +// b0, g0, r0, ..., bn-1, gn-1, rn-1 +// n_link_color +// i0, i1, ..., in-1 +// n_point_color +// j0, j1, ..., jn-1 +inline Skeleton Skeleton::get(const std::string& path) { + if (path == "coco") { + return gCocoSkeleton(); + } + std::ifstream ifs(path); + if (!ifs.is_open()) { + std::cout << "error: failed to open skeleton data file: " << path << "\n"; + std::abort(); + } + Skeleton skel; + int n = 0; + ifs >> n; + for (int i = 0; i < n; ++i) { + int u{}, v{}; + ifs >> u >> v; + skel.links.emplace_back(u, v); + } + ifs >> n; + for (int i = 0; i < n; ++i) { + int b{}, g{}, r{}; + ifs >> b >> g >> r; + skel.palette.emplace_back(b, g, r); + } + ifs >> n; + for (int i = 0; i < n; ++i) { + int x{}; + ifs >> x; + skel.link_colors.push_back(x); + } + ifs >> n; + for (int i = 0; i < n; ++i) { + int x{}; + ifs >> x; + skel.point_colors.push_back(x); + } + return skel; +} + +} // namespace 
utils + +#endif // MMDEPLOY_SKELETON_H diff --git a/demo/csrc/cpp/utils/visualize.h b/demo/csrc/cpp/utils/visualize.h new file mode 100644 index 0000000000..e9d8493f58 --- /dev/null +++ b/demo/csrc/cpp/utils/visualize.h @@ -0,0 +1,252 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_VISUALIZE_H +#define MMDEPLOY_VISUALIZE_H + +#include +#include +#include +#include + +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" +#include "palette.h" +#include "skeleton.h" + +namespace utils { + +class Visualize { + public: + class Session { + public: + explicit Session(Visualize& v, const cv::Mat& frame) : v_(v) { + if (v_.size_) { + scale_ = (float)v_.size_ / (float)std::max(frame.cols, frame.rows); + } + cv::Mat img; + if (v.background_ == "black") { + img = cv::Mat::zeros(frame.size(), CV_8UC3); + } else { + img = frame; + if (img.channels() == 1) { + cv::cvtColor(img, img, cv::COLOR_GRAY2BGR); + } + } + if (scale_ != 1) { + cv::resize(img, img, {}, scale_, scale_); + } else if (img.data == frame.data) { + img = img.clone(); + } + img_ = std::move(img); + } + + void add_label(int label_id, float score, int index) { + printf("label: %d, label_id: %d, score: %.4f\n", index, label_id, score); + auto size = .5f * static_cast(img_.rows + img_.cols); + offset_ += add_text(to_text(label_id, score), {1, (float)offset_}, size) + 2; + } + + int add_text(const std::string& text, const cv::Point2f& origin, float size) { + static constexpr const int font_face = cv::FONT_HERSHEY_SIMPLEX; + static constexpr const int thickness = 1; + static constexpr const auto max_font_scale = .5f; + static constexpr const auto min_font_scale = .25f; + float font_scale{}; + if (size < 20) { + font_scale = min_font_scale; + } else if (size > 200) { + font_scale = max_font_scale; + } else { + font_scale = min_font_scale + (size - 20) / (200 - 20) * (max_font_scale - min_font_scale); + } + int baseline{}; + auto text_size = cv::getTextSize(text, 
font_face, font_scale, thickness, &baseline); + cv::Rect rect(origin + cv::Point2f(0, text_size.height + 2 * thickness), + origin + cv::Point2f(text_size.width, 0)); + rect &= cv::Rect({}, img_.size()); + img_(rect) *= .35f; + cv::putText(img_, text, origin + cv::Point2f(0, text_size.height), font_face, font_scale, + cv::Scalar::all(255), thickness, cv::LINE_AA); + return text_size.height; + } + + static std::string to_text(int label_id, float score) { + std::stringstream ss; + ss << label_id << ": " << std::fixed << std::setprecision(1) << score * 100; + return ss.str(); + } + + template + void add_det(const mmdeploy_rect_t& rect, int label_id, float score, const Mask* mask, + int index) { + printf("bbox %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n", index, + rect.left, rect.top, rect.right, rect.bottom, label_id, score); + if (mask) { + fprintf(stdout, "mask %d, height=%d, width=%d\n", index, mask->height, mask->width); + auto x0 = (int)std::max(std::floor(rect.left) - 1, 0.f); + auto y0 = (int)std::max(std::floor(rect.top) - 1, 0.f); + add_instance_mask({x0, y0}, rand(), mask->data, mask->height, mask->width); + } + add_bbox(rect, label_id, score); + } + + void add_instance_mask(const cv::Point& origin, int color_id, const char* mask_data, int mask_h, + int mask_w, float alpha = .5f) { + auto color = v_.palette_.data[color_id % v_.palette_.data.size()]; + auto x_end = std::min(origin.x + mask_w, img_.cols); + auto y_end = std::min(origin.y + mask_h, img_.rows); + auto img_data = img_.ptr(); + for (int i = origin.y; i < y_end; ++i) { + for (int j = origin.x; j < x_end; ++j) { + if (mask_data[(i - origin.y) * mask_w + (j - origin.x)]) { + img_data[i * img_.cols + j] = img_data[i * img_.cols + j] * (1 - alpha) + color * alpha; + } + } + } + } + + void add_bbox(mmdeploy_rect_t rect, int label_id, float score) { + rect.left *= scale_; + rect.right *= scale_; + rect.top *= scale_; + rect.bottom *= scale_; + if (label_id >= 0 && score > 0) { 
+ auto area = std::max(0.f, (rect.right - rect.left) * (rect.bottom - rect.top)); + add_text(to_text(label_id, score), {rect.left, rect.top}, std::sqrt(area)); + } + cv::rectangle(img_, cv::Point2f(rect.left, rect.top), cv::Point2f(rect.right, rect.bottom), + cv::Scalar(0, 255, 0)); + } + + void add_text_det(mmdeploy_point_t bbox[4], float score, const char* text, size_t text_size, + int index) { + printf("bbox[%d]: (%.2f, %.2f), (%.2f, %.2f), (%.2f, %.2f), (%.2f, %.2f)\n", index, // + bbox[0].x, bbox[0].y, // + bbox[1].x, bbox[1].y, // + bbox[2].x, bbox[2].y, // + bbox[3].x, bbox[3].y); + std::vector poly_points; + cv::Point2f center{}; + for (int i = 0; i < 4; ++i) { + poly_points.emplace_back(bbox[i].x * scale_, bbox[i].y * scale_); + center += cv::Point2f(poly_points.back()); + } + cv::polylines(img_, poly_points, true, cv::Scalar{0, 255, 0}, 1, cv::LINE_AA); + if (text) { + auto area = cv::contourArea(poly_points); + fprintf(stdout, "text[%d]: %s\n", index, text); + add_text(std::string(text, text + text_size), center / 4, std::sqrt(area)); + } + } + + void add_rotated_det(const float bbox[5], int label_id, float score) { + float xc = bbox[0] * scale_; + float yc = bbox[1] * scale_; + float w = bbox[2] * scale_; + float h = bbox[3] * scale_; + float ag = bbox[4]; + float wx = w / 2 * std::cos(ag); + float wy = w / 2 * std::sin(ag); + float hx = -h / 2 * std::sin(ag); + float hy = h / 2 * std::cos(ag); + cv::Point2f p1{xc - wx - hx, yc - wy - hy}; + cv::Point2f p2{xc + wx - hx, yc + wy - hy}; + cv::Point2f p3{xc + wx + hx, yc + wy + hy}; + cv::Point2f p4{xc - wx + hx, yc - wy + hy}; + cv::Point2f c = .25f * (p1 + p2 + p3 + p4); + cv::drawContours( + img_, + std::vector>{{p1 * scale_, p2 * scale_, p3 * scale_, p4 * scale_}}, + -1, {0, 255, 0}, 2, cv::LINE_AA); + add_text(to_text(label_id, score), c, std::sqrt(w * h)); + } + + void add_mask(int height, int width, int n_classes, const int* mask, const float* score) { + cv::Mat color_mask = cv::Mat::zeros(height, 
width, CV_8UC3); + auto n_pix = color_mask.total(); + + // compute top 1 idx if score (CHW) is available + cv::Mat_ top; + if (!mask && score) { + top = cv::Mat_::zeros(height, width); + for (auto c = 1; c < n_classes; ++c) { + top.forEach([&](int& x, const int* idx) { + auto offset = idx[0] * width + idx[1]; + if (score[c * n_pix + offset] > score[x * n_pix + offset]) { + x = c; + } + }); + } + mask = top.ptr(); + } + + if (mask) { + // palette look-up + color_mask.forEach([&](cv::Vec3b& x, const int* idx) { + auto& palette = v_.palette_.data; + x = palette[mask[idx[0] * width + idx[1]] % palette.size()]; + }); + + if (color_mask.size() != img_.size()) { + cv::resize(color_mask, color_mask, img_.size()); + } + + // blend mask and background image + cv::addWeighted(img_, .5, color_mask, .5, 0., img_); + } + } + + void add_pose(const mmdeploy_point_t* pts, const float* scores, int32_t pts_size, double thr) { + auto& skel = v_.skeleton_; + std::vector used(pts_size); + std::vector is_end_point(pts_size); + for (size_t i = 0; i < skel.links.size(); ++i) { + auto u = skel.links[i].first; + auto v = skel.links[i].second; + is_end_point[u] = is_end_point[v] = 1; + if (scores[u] > thr && scores[v] > thr) { + used[u] = used[v] = 1; + cv::Point2f p0(pts[u].x, pts[u].y); + cv::Point2f p1(pts[v].x, pts[v].y); + cv::line(img_, p0 * scale_, p1 * scale_, skel.palette[skel.link_colors[i]], 1, + cv::LINE_AA); + } + } + for (size_t i = 0; i < pts_size; ++i) { + if (!is_end_point[i] && scores[i] > thr || used[i]) { + cv::Point2f p(pts[i].x, pts[i].y); + cv::circle(img_, p * scale_, 1, skel.palette[skel.point_colors[i]], 2, cv::LINE_AA); + } + } + } + + cv::Mat get() { return img_; } + + private: + Visualize& v_; + float scale_{1}; + int offset_{1}; + cv::Mat img_; + }; + + explicit Visualize(int size = 0) : size_(size) { palette_ = Palette::get(32); } + + Session get_session(const cv::Mat& frame) { return Session(*this, frame); } + + void set_skeleton(const Skeleton& skeleton) { 
skeleton_ = skeleton; } + + void set_palette(const Palette& palette) { palette_ = palette; } + + void set_background(const std::string& background) { background_ = background; } + + private: + friend Session; + Skeleton skeleton_; + Palette palette_; + std::string background_; + int size_{}; +}; + +} // namespace utils + +#endif // MMDEPLOY_VISUALIZE_H diff --git a/demo/csrc/cpp/video_cls.cxx b/demo/csrc/cpp/video_cls.cxx index 95c9ef239e..ce08a50759 100644 --- a/demo/csrc/cpp/video_cls.cxx +++ b/demo/csrc/cpp/video_cls.cxx @@ -3,8 +3,8 @@ #include #include "mmdeploy/video_recognizer.hpp" -#include "opencv2/imgcodecs/imgcodecs.hpp" #include "opencv2/videoio.hpp" +#include "utils/argparse.h" void SampleFrames(const char* video_path, std::map& buffer, std::vector& clips, int clip_len, int frame_interval = 1, @@ -57,14 +57,14 @@ void SampleFrames(const char* video_path, std::map& buffer, } } +DEFINE_ARG_string(model, "Model path"); +DEFINE_ARG_string(video, "Input video path"); +DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")"); + int main(int argc, char* argv[]) { - if (argc != 4) { - fprintf(stderr, "usage:\n video_cls device_name model_path video_path\n"); - return 1; + if (!utils::ParseArguments(argc, argv)) { + return -1; } - auto device_name = argv[1]; - auto model_path = argv[2]; - auto video_path = argv[3]; int clip_len = 1; int frame_interval = 1; @@ -73,10 +73,10 @@ int main(int argc, char* argv[]) { std::map buffer; std::vector clips; mmdeploy::VideoSampleInfo clip_info = {clip_len, num_clips}; - SampleFrames(video_path, buffer, clips, clip_len, frame_interval, num_clips); + SampleFrames(ARGS_video.c_str(), buffer, clips, clip_len, frame_interval, num_clips); - mmdeploy::Model model(model_path); - mmdeploy::VideoRecognizer recognizer(model, mmdeploy::Device{device_name, 0}); + mmdeploy::Model model(ARGS_model); + mmdeploy::VideoRecognizer recognizer(model, mmdeploy::Device{FLAGS_device}); auto res = recognizer.Apply(clips, clip_info);