Skip to content

Commit

Permalink
#9 fixed YOLOX inference error for non-square shape
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth committed Jul 31, 2021
1 parent fd92a3c commit 4150622
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 7 deletions.
4 changes: 3 additions & 1 deletion hub/onnx/cv/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,6 @@ PFLD-pytorch-pfld.onnx
pytorch_face_landmarks_landmark_detection_56.onnx
pytorch_face_landmarks_landmark_detection_56_se_external.onnx
pytorch_face_landmarks_pfld.onnx
FaceLandmark1000.onnx
FaceLandmark1000.onnx
Pytorch_RetinaFace_mobile0.25.onnx
Pytorch_RetinaFace_resnet50.onnx
9 changes: 8 additions & 1 deletion lite/lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include "ort/cv/mobilenetv2_68.h"
#include "ort/cv/mobilenetv2_se_68.h"
#include "ort/cv/face_landmarks_1000.h"
#include "ort/cv/retinaface.h"

#endif

Expand Down Expand Up @@ -144,6 +145,7 @@ namespace lite
typedef ortcv::MobileNetV268 _MobileNetV268;
typedef ortcv::MobileNetV2SE68 _MobileNetV2SE68;
typedef ortcv::FaceLandmark1000 _FaceLandmark1000;
typedef ortcv::RetinaFace _RetinaFace;
#endif

// 1. classification
Expand Down Expand Up @@ -197,11 +199,14 @@ namespace lite
typedef _MobileNetV268 MobileNetV268;
typedef _MobileNetV2SE68 MobileNetV2SE68;
typedef _FaceLandmark1000 FaceLandmark1000;
typedef _RetinaFace RetinaFace;
#endif
namespace detect
{
#ifdef BACKEND_ONNXRUNTIME
typedef _UltraFace UltraFace; // face detection.
typedef _RetinaFace RetinaFace;

#endif
}

Expand Down Expand Up @@ -401,7 +406,7 @@ namespace lite
typedef ortcv::MobileNetV268 _ONNXMobileNetV268;
typedef ortcv::MobileNetV2SE68 _ONNXMobileNetV2SE68;
typedef ortcv::FaceLandmark1000 _ONNXFaceLandmark1000;

typedef ortcv::RetinaFace _ONNXRetinaFace;

// 1. classification
namespace classification
Expand Down Expand Up @@ -449,10 +454,12 @@ namespace lite
typedef _ONNXMobileNetV268 MobileNetV268;
typedef _ONNXMobileNetV2SE68 MobileNetV2SE68;
typedef _ONNXFaceLandmark1000 FaceLandmark1000;
typedef _ONNXRetinaFace _RetinaFace;

namespace detect
{
typedef _ONNXUltraFace UltraFace; // face detection.
typedef _ONNXRetinaFace _RetinaFace;
}

namespace align
Expand Down
7 changes: 3 additions & 4 deletions ort/core/ort_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ namespace ortcv
class LITE_EXPORTS MobileNetV2SE68; // [51] * reference: https://github.com/cunjian/pytorch_face_landmark
class LITE_EXPORTS PFLD68; // [52] * reference: https://github.com/cunjian/pytorch_face_landmark
class LITE_EXPORTS FaceLandmark1000; // [53] * reference: https://github.com/Single430/FaceLandmark1000
class LITE_EXPORTS MobileV1RetinaFace; // [54] reference: https://github.com/biubug6/Pytorch_Retinaface
class LITE_EXPORTS ResNetRetinaFace; // [55] reference: https://github.com/biubug6/Pytorch_Retinaface
class LITE_EXPORTS FaceBoxes; // [56] reference: https://github.com/zisianw/FaceBoxes.PyTorch
class LITE_EXPORTS YoloX; // [57] * reference: https://github.com/Megvii-BaseDetection/YOLOX
class LITE_EXPORTS RetinaFace; // [54] reference: https://github.com/biubug6/Pytorch_Retinaface
class LITE_EXPORTS FaceBoxes; // [55] reference: https://github.com/zisianw/FaceBoxes.PyTorch
class LITE_EXPORTS YoloX; // [56] * reference: https://github.com/Megvii-BaseDetection/YOLOX
}

namespace ortnlp
Expand Down
162 changes: 162 additions & 0 deletions ort/cv/retinaface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//
// Created by DefTruth on 2021/7/31.
//

#include "retinaface.h"
#include "ort/core/ort_utils.h"

using ortcv::RetinaFace;

Ort::Value RetinaFace::transform(const cv::Mat &mat)
{
cv::Mat canva = mat.clone();
cv::resize(canva, canva, cv::Size(input_node_dims.at(3),
input_node_dims.at(2)));
// (1,3,640,640) 1xCXHXW

ortcv::utils::transform::normalize_inplace(canva, mean_vals, scale_vals); // float32
return ortcv::utils::transform::create_tensor(
canva, input_node_dims, memory_info_handler,
input_values_handler, ortcv::utils::transform::CHW);
}

void RetinaFace::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes,
float score_threshold, float iou_threshold,
unsigned int topk, unsigned int nms_type)
{
if (mat.empty()) return;
// this->transform(mat);
float img_height = static_cast<float>(mat.rows);
float img_width = static_cast<float>(mat.cols);

// 1. make input tensor
Ort::Value input_tensor = this->transform(mat);
// 2. inference scores & boxes.
auto output_tensors = ort_session->Run(
Ort::RunOptions{nullptr}, input_node_names.data(),
&input_tensor, 1, output_node_names.data(), num_outputs
);
// 3. rescale & exclude.
std::vector<types::Boxf> bbox_collection;
this->generate_bboxes(bbox_collection, output_tensors, score_threshold, img_height, img_width);
// 4. hard|blend nms with topk.
this->nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type);
}

// ref: https://github.com/biubug6/Pytorch_Retinaface/blob/master/layers/functions/prior_box.py
void RetinaFace::generate_anchors(const int target_height,
const int target_width,
std::vector<RetinaAnchor> &anchors)
{
std::vector<std::vector<int>> feature_maps;
for (auto step: steps)
{
feature_maps.push_back(
{
std::ceilf((float) target_height / (float) step),
std::ceilf((float) target_width / (float) step)
} // ceil
);
}

anchors.clear();
const int num_feature_map = feature_maps.size();

for (int k = 0; k < num_feature_map; ++k)
{
auto f_map = feature_maps.at(k); // e.g [640//8,640//8]
auto tmp_min_sizes = min_sizes.at(k); // e.g [8,16]
int f_h = f_map.at(0), f_map = f_.at(1);
for (int i = 0; i < f_h; ++i)
{
for (int j = 0; j < f_w; ++j)
{
for (auto min_size: tmp_min_sizes)
{
float s_kx = (float) min_size / (float) target_width; // e.g 16/w
float s_ky = (float) min_size / (float) target_height; // e.g 16/h
// (x + 0.5) * step / w normalized loc mapping to input width
// (y + 0.5) * step / h normalized loc mapping to input height
float cx = ((float) j + 0.5f) * (float) steps.at(k) / (float) target_width;
float cy = ((float) i + 0.5f) * (float) steps.at(k) / (float) target_height;

anchors.push_back((RetinaAnchor) {cx, cy, s_kx, s_ky}); // without clip,
}
}
}
}
}


void RetinaFace::generate_bboxes(std::vector<types::Boxf> &bbox_collection,
std::vector<Ort::Value> &output_tensors,
float score_threshold,
float img_height, float img_width)
{
Ort::Value &bboxes = output_tensors.at(0); // e.g (1,16800,4)
Ort::Value &probs = output_tensors.at(1); // e.g (1,16800,2) after softmax
auto bbox_dims = output_node_dims.at(0); // (1,16800,4)
const unsigned int bbox_num = bbox_dims.at(1); // n = ?
const float input_height = static_cast<float>(input_node_dims.at(2)); // e.g 640
const float input_width = static_cast<float>(input_node_dims.at(3)); // e.g 640

std::vector<RetinaAnchor> anchors;
this->generate_anchors(input_height, input_width, anchors);

const unsigned int num_anchors = anchors.size();

if (num_anchors != bbox_num)
throw std::runtime_error("mismatch num_anchors != bbox_num");

bbox_collection.clear();
unsigned int count = 0;
for (unsigned int i = 0; i < num_anchors; ++i)
{
float conf = probs.At<float>({0, i, 1});
if (conf < score_threshold) continue; // filter first.

float prior_cx = anchors.at(i).cx;
float prior_cy = anchors.at(i).cy;
float prior_s_kx = anchors.at(i).s_kx;
float prior_s_ky = anchors.at(i).s_ky;

float dx = bboxes.At<float>({0, i, 0});
float dy = bboxes.At<float>({0, i, 1});
float dw = bboxes.At<float>({0, i, 2});
float dh = bboxes.At<float>({0, i, 3});

// ref: https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py
float cx = prior_cx + dx * variance[0] * prior_cx;
float cy = prior_cy + dy * variance[0] * prior_cy;
float w = prior_s_kx * std::expf(dw * variance[1]);
float h = prior_s_ky * std::expf(dh * variance[1]); // norm coor (0.,1.)

types::Boxf box;
box.x1 = (cx - w / 2.f) * img_width;
box.y1 = (cy - h / 2.f) * img_height;
box.x2 = (cx + w / 2.f) * img_width;
box.y2 = (cy + h / 2.f) * img_height;
box.score = conf;
box.label = 1;
box.label_text = "face";
box.flag = true;
bbox_collection.push_back(box);

count += 1; // limit boxes for nms.
if (count > max_nms)
break;
}
#if LITEORT_DEBUG
std::cout << "detected num_anchors: " << num_anchors << "\n";
std::cout << "generate_bboxes num: " << bbox_collection.size() << "\n";
#endif
}


void RetinaFace::nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output,
float iou_threshold, unsigned int topk, unsigned int nms_type)
{
if (nms_type == NMS::BLEND) ortcv::utils::blending_nms(input, output, iou_threshold, topk);
else if (nms_type == NMS::OFFSET) ortcv::utils::offset_nms(input, output, iou_threshold, topk);
else ortcv::utils::hard_nms(input, output, iou_threshold, topk);
}
69 changes: 69 additions & 0 deletions ort/cv/retinaface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//
// Created by DefTruth on 2021/7/31.
//

#ifndef LITE_AI_ORT_CV_RETINAFACE_H
#define LITE_AI_ORT_CV_RETINAFACE_H

#include "ort/core/ort_core.h"

namespace ortcv
{

// reference: Pytorch_Retinaface python implementation.
// https://github.com/biubug6/Pytorch_Retinaface/blob/master/layers/functions/prior_box.py
struct RetinaAnchor
{
float cx;
float cy;
float s_kx;
float s_ky;
};

class LITE_EXPORTS RetinaFace : public BasicOrtHandler
{
public:
explicit RetinaFace(const std::string &_onnx_path, unsigned int _num_threads = 1) :
BasicOrtHandler(_onnx_path, _num_threads)
{};

~RetinaFace() override = default;

private:
const float mean_vals[3] = {104.f, 117.f, 123.f}; // bgr order
const float scale_vals[3] = {1.f, 1.f, 1.f};
const float variance[2] = {0.1f, 0.2f};
std::vector<int> steps = {8, 16, 32};
std::vector<std::vector<int>> min_sizes = {{16, 32}, {64, 128}, {256, 512}};

enum NMS
{
HARD = 0, BLEND = 1, OFFSET = 2
};
static constexpr const unsigned int max_nms = 30000;

private:
Ort::Value transform(const cv::Mat &mat) override; //

void generate_anchors(const int target_height,
const int target_width,
std::vector<RetinaAnchor> &anchors);

void generate_bboxes(std::vector<types::Boxf> &bbox_collection,
std::vector<Ort::Value> &output_tensors,
float score_threshold, float img_height,
float img_width); // rescale & exclude

void nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output,
float iou_threshold, unsigned int topk, unsigned int nms_type);

public:
void detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes,
float score_threshold = 0.25f, float iou_threshold = 0.45f,
unsigned int topk = 100, unsigned int nms_type = NMS::OFFSET);

};
}


#endif //LITE_AI_ORT_CV_RETINAFACE_H
1 change: 0 additions & 1 deletion ort/cv/ultraface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ void UltraFace::generate_bboxes(std::vector<types::Boxf> &bbox_collection,
Ort::Value &scores = output_tensors.at(0);
Ort::Value &boxes = output_tensors.at(1);
auto scores_dims = output_node_dims.at(0); // (1,n,2)
auto boxes_dims = output_node_names.at(1); // (1,n,4) x1,y1,x2,y2
const unsigned int num_anchors = scores_dims.at(1); // n = 17640 (640x480)

bbox_collection.clear();
Expand Down

0 comments on commit 4150622

Please sign in to comment.