Skip to content

Commit

Permalink
Add letterbox for TRT C++ inference (#305)
Browse files Browse the repository at this point in the history
* Add letter box

* format code

* modify label type to int

* rename function letterbox
  • Loading branch information
ShiquanYu authored Feb 9, 2022
1 parent ba6c637 commit be343c5
Showing 1 changed file with 97 additions and 35 deletions.
132 changes: 97 additions & 35 deletions deployment/tensorrt/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,85 @@ void visualizeDetection(
}
}

/*
 * Letterbox an image: resize `src` with its aspect ratio preserved and paste
 * it onto a canvas `dst` that is pre-filled with `fill`.
 *
 * Parameters:
 *   src          - source image; must not be empty.
 *   dst          - output canvas, (re)allocated here as CV_8UC3.
 *   dst_size     - length in pixels of the longer canvas side; must be > 0.
 *   align_size   - when > 0, the derived (shorter) canvas side is rounded up to
 *                  a multiple of this value; when <= 0 no alignment is applied.
 *   fill         - colour used for the padding area.
 *   dst_hw_scale - desired canvas height / width ratio. When > 0 the canvas
 *                  shape is fixed from `dst_size` and this ratio; when <= 0 the
 *                  canvas follows the source aspect ratio (longer side equals
 *                  `dst_size`).
 *   simple_mode  - when true the resized image is placed at the top-left
 *                  corner; otherwise it is centred on the canvas.
 *
 * Returns the ratio source size / resized size (multiply a coordinate on the
 * resized image by it to get the source coordinate), or -1 if `src` is empty
 * or `dst_size` is not positive.
 */
float letterbox(
    const cv::Mat& src,
    cv::Mat& dst,
    int dst_size,
    int align_size,
    cv::Scalar fill,
    const float dst_hw_scale = -1.f,
    bool simple_mode = false) {
  float scale = -1.f;

  if (src.empty()) {
    std::cerr << "Assert source image empty!" << std::endl;
    return scale;
  }
  if (dst_size <= 0) {
    std::cerr << "Assert letterbox dst_size must be positive!" << std::endl;
    return scale;
  }

  // Round `val` up to the next multiple of align_size. Unlike a bitmask this
  // is also correct for alignments that are not powers of two.
  const auto align_up = [align_size](int val) -> int {
    if (align_size <= 0) {
      return val;
    }
    return ((val + align_size - 1) / align_size) * align_size;
  };
  // Keep a dimension inside [1, upper]; float truncation can otherwise yield 0
  // for extreme aspect ratios, which cv::resize / Mat::create reject.
  const auto clamp_dim = [](int val, int upper) -> int {
    if (val < 1) {
      return 1;
    }
    return val > upper ? upper : val;
  };

  int input_h = -1;     // height of the resized image
  int input_w = -1;     // width of the resized image
  int lb_input_h = -1;  // height of the letterbox canvas
  int lb_input_w = -1;  // width of the letterbox canvas

  if (dst_hw_scale > 0) {
    // Fixed canvas shape: the longer side is dst_size, the other side follows
    // the requested height/width ratio.
    if (dst_hw_scale > 1) {
      lb_input_h = dst_size;
      lb_input_w = align_up(clamp_dim(static_cast<int>(dst_size / dst_hw_scale), dst_size));
    } else {
      lb_input_w = dst_size;
      lb_input_h = align_up(clamp_dim(static_cast<int>(dst_size * dst_hw_scale), dst_size));
    }
    // The axis that needs the stronger shrink decides the common scale, so the
    // resized image fits inside the canvas on both axes.
    const float scale_h = static_cast<float>(src.rows) / lb_input_h;
    const float scale_w = static_cast<float>(src.cols) / lb_input_w;
    if (scale_w > scale_h) {
      input_w = lb_input_w;
      scale = scale_w;
      input_h = clamp_dim(static_cast<int>(src.rows / scale), lb_input_h);
    } else {
      input_h = lb_input_h;
      scale = scale_h;
      input_w = clamp_dim(static_cast<int>(src.cols / scale), lb_input_w);
    }
  } else {
    // Canvas follows the source aspect ratio: the longer source side maps to
    // dst_size, the other side is scaled and then aligned with align_size.
    if (src.cols > src.rows) {
      input_w = dst_size;
      scale = static_cast<float>(src.cols) / dst_size;
      input_h = clamp_dim(static_cast<int>(src.rows / scale), dst_size);
      lb_input_w = dst_size;
      lb_input_h = align_up(input_h);
    } else {
      input_h = dst_size;
      scale = static_cast<float>(src.rows) / dst_size;
      input_w = clamp_dim(static_cast<int>(src.cols / scale), dst_size);
      lb_input_h = dst_size;
      lb_input_w = align_up(input_w);
    }
  }

  // Resize before touching dst so that src is fully read first, even if the
  // caller passes the same matrix for both arguments.
  cv::Mat rs_img{};
  cv::resize(src, rs_img, cv::Size(input_w, input_h));

  dst.create(lb_input_h, lb_input_w, CV_8UC3);
  dst.setTo(fill);

  int start_x = 0;
  int start_y = 0;
  if (!simple_mode) {
    start_x = (dst.cols - rs_img.cols) / 2;
    start_y = (dst.rows - rs_img.rows) / 2;
  }
  cv::Rect roi_rect{start_x, start_y, rs_img.cols, rs_img.rows};
  rs_img.copyTo(dst(roi_rect));

  return scale;
}

std::vector<std::string> loadNames(const std::string& path) {
// load class names
std::vector<std::string> classNames;
Expand Down Expand Up @@ -163,32 +242,6 @@ ICudaEngine* CreateCudaEngineFromOnnx(
config->setDLACore(builder->getNbDLACores());
}

// TODO: dynamic input, haven't figured out how to handle this yet
// IOptimizationProfile* profile = builder->createOptimizationProfile();
// if (!profile) {
// return nullptr;
// }
//
// {
// Dims dim = network->getInput(0)->getDimensions();
// const char* name = network->getInput(0)->getName();
// profile->setDimensions(name, OptProfileSelector::kMIN, Dims4(1, dim.d[1], dim.d[2],
// dim.d[3])); profile->setDimensions(name, OptProfileSelector::kOPT, Dims4(1, dim.d[1],
// dim.d[2], dim.d[3])); profile->setDimensions(
// name,
// OptProfileSelector::kMAX,
// Dims4(builder->getMaxBatchSize(), dim.d[1], dim.d[2], dim.d[3]));
// // profile->setDimensions(name, OptProfileSelector::kMIN, Dims4(1, 3, 192, 320));
// // profile->setDimensions(name, OptProfileSelector::kOPT, Dims4(1, 3, 256, 416));
// // profile->setDimensions(name, OptProfileSelector::kMAX, Dims4(1, 3, 640, 640));
// if (profile->isValid()) {
// config->addOptimizationProfile(profile);
// } else {
// std::cout << "profile is invalid!\n" << std::endl;
// exit(-1);
// }
// }

std::unique_ptr<IHostMemory> serializedModel{
builder->buildSerializedNetwork(*network.get(), *config.get())};
if (!serializedModel) {
Expand Down Expand Up @@ -257,8 +310,8 @@ YOLOv5Detector::~YOLOv5Detector() {

std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
std::vector<Detection> result;
std::vector<void*> buffers(engine->getNbBindings());
std::size_t batch_size = 1;
void* buffers[engine->getNbBindings()];
int num_detections_index = engine->getBindingIndex("num_detections");
int detection_boxes_index = engine->getBindingIndex("detection_boxes");
int detection_scores_index = engine->getBindingIndex("detection_scores");
Expand Down Expand Up @@ -304,7 +357,14 @@ std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
int32_t input_h = engine->getBindingDimensions(0).d[2];
int32_t input_w = engine->getBindingDimensions(0).d[3];
cv::Mat tmp;
cv::resize(image, tmp, cv::Size(input_w, input_h));
/* Fixed shape, need to set h/w scale */
float scale = letterbox(
image,
tmp,
std::max(input_w, input_h),
-1,
cv::Scalar(114, 114, 114),
float(input_h) / input_w);
tmp.convertTo(tmp, CV_32FC3, 1 / 255.0);
{
/* HWC ==> CHW */
Expand All @@ -321,7 +381,7 @@ std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
offset = split_image.total();
}
}
context->enqueueV2(buffers, stream, nullptr);
context->enqueueV2(buffers.data(), stream, nullptr);

for (int32_t i = 1; i < engine->getNbBindings(); i++) {
if (i == detection_boxes_index) {
Expand Down Expand Up @@ -356,15 +416,17 @@ std::vector<Detection> YOLOv5Detector::detect(cv::Mat& image) {
}

/* Convert box format from LTRB to LTWH */
int x_offset = (input_w * scale - image.cols) / 2;
int y_offset = (input_h * scale - image.rows) / 2;
for (int32_t i = 0; i < num_detections; i++) {
Detection detection;
detection.box.x = detection_boxes[4 * i] * image.cols / input_w;
detection.box.y = detection_boxes[4 * i + 1] * image.rows / input_h;
detection.box.width = detection_boxes[4 * i + 2] * image.cols / input_w - detection.box.x;
detection.box.height = detection_boxes[4 * i + 3] * image.rows / input_h - detection.box.y;
result.emplace_back();
Detection& detection = result.back();
detection.box.x = detection_boxes[4 * i] * scale - x_offset;
detection.box.y = detection_boxes[4 * i + 1] * scale - y_offset;
detection.box.width = detection_boxes[4 * i + 2] * scale - x_offset - detection.box.x;
detection.box.height = detection_boxes[4 * i + 3] * scale - y_offset - detection.box.y;
detection.classId = detection_labels[i];
detection.conf = detection_scores[i];
result.push_back(detection);
}

return result;
Expand Down

0 comments on commit be343c5

Please sign in to comment.