Skip to content

Commit

Permalink
[TRT] support MODNet (#442)
Browse files Browse the repository at this point in the history
* add modnet trt  test code

* modnet trt implement

* update code

* add trt modnet
  • Loading branch information
wangzijian1010 authored Oct 28, 2024
1 parent 557521d commit 4252d27
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 5 deletions.
60 changes: 55 additions & 5 deletions examples/lite/cv/test_lite_modnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,55 @@ static void test_onnxruntime()
#endif
}



/**
 * Runs the MODNet matting demo on the TensorRT backend.
 *
 * Loads the FP16 engine and the two sample images, runs matting with noise
 * removal and minimum post-processing enabled, then writes the alpha matte and
 * the background-swapped image (plus fgr/merge when they were produced) under
 * examples/logs. The whole body is compiled out unless ENABLE_TENSORRT is
 * defined.
 */
static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
  const std::string engine_path = "../../../examples/hub/trt/modnet_fp16.engine";
  const std::string test_img_path = "../../../examples/lite/resources/test_lite_matting_input.jpg";
  const std::string test_bgr_path = "../../../examples/lite/resources/test_lite_matting_bgr.jpg";
  const std::string save_fgr_path = "../../../examples/logs/test_lite_modnet_fgr_trt.jpg";
  const std::string save_pha_path = "../../../examples/logs/test_lite_modnet_pha_trt.jpg";
  const std::string save_merge_path = "../../../examples/logs/test_lite_modnet_merge_trt.jpg";
  const std::string save_swap_path = "../../../examples/logs/test_lite_modnet_swap_trt.jpg";

  cv::Mat img_bgr = cv::imread(test_img_path);
  cv::Mat bgr_mat = cv::imread(test_bgr_path);
  // cv::imread returns an empty Mat on failure; bail out early instead of
  // feeding empty images into the model and the background swap.
  if (img_bgr.empty() || bgr_mat.empty())
  {
    std::cerr << "Failed to read test images: " << test_img_path
              << " / " << test_bgr_path << std::endl;
    return;
  }

  // Automatic storage: the handler is released on every exit path, including
  // an exception thrown from detect().
  lite::trt::cv::matting::MODNet modnet(engine_path);
  lite::types::MattingContent content;

  // 1. image matting.
  modnet.detect(img_bgr, content, true, true);

  if (content.flag)
  {
    if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
    // pha_mat is scaled by 255 before being written as an 8-bit image.
    if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
    if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);

    // 2. swap background: use the predicted foreground when available,
    // otherwise fall back to the original image.
    cv::Mat out_mat;
    if (!content.fgr_mat.empty())
      lite::utils::swap_background(content.fgr_mat, content.pha_mat, bgr_mat, out_mat, true);
    else
      lite::utils::swap_background(img_bgr, content.pha_mat, bgr_mat, out_mat, false);

    if (!out_mat.empty())
    {
      cv::imwrite(save_swap_path, out_mat);
      std::cout << "Saved Swap Image Done!" << std::endl;
    }

    std::cout << "TensorRT Version MODNet Done!" << std::endl;
  }
#endif
}


static void test_mnn()
{
#ifdef ENABLE_MNN
Expand Down Expand Up @@ -233,11 +282,12 @@ static void test_tnn()

/**
 * Runs the MODNet demo on every backend.
 *
 * The backend-specific tests visible in this file wrap their bodies in an
 * ENABLE_* guard, so calling a test for a backend that was not built is a
 * no-op. All backends stay enabled here so adding the TensorRT test does not
 * silently drop coverage of the others.
 */
static void test_lite()
{
  test_default();
  test_onnxruntime();
  test_mnn();
  test_ncnn();
  test_tnn();
  test_tensorrt();
}

int main(__unused int argc, __unused char *argv[])
Expand Down
6 changes: 6 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
#include "lite/trt/cv/trt_yolox.h"
#include "lite/trt/cv/trt_yolov8.h"
#include "lite/trt/cv/trt_yolov6.h"
#include "lite/trt/cv/trt_modnet.h"
#include "lite/trt/cv/trt_yolov5_blazeface.h"
#include "lite/trt/cv/trt_lightenhance.h"
#include "lite/trt/cv/trt_realesrgan.h"
Expand Down Expand Up @@ -731,9 +732,14 @@ namespace lite{
typedef trtcv::TRTYOLO5Face _TRT_YOLO5Face;
typedef trtcv::TRTLightEnhance _TRT_LightEnhance;
typedef trtcv::TRTRealESRGAN _TRT_RealESRGAN;
// TensorRT MODNet handler; re-exported under the matting namespace below.
typedef trtcv::TRTMODNet _TRT_MODNet;
// No classification models are aliased in this namespace yet.
namespace classification
{

}
// Matting models available on the TensorRT backend.
namespace matting
{
// Public alias for the TensorRT MODNet handler (the example test refers to it
// as lite::trt::cv::matting::MODNet).
typedef _TRT_MODNet MODNet;
}
namespace detection
{
Expand Down
42 changes: 42 additions & 0 deletions lite/trt/core/trt_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,46 @@ void trtcv::utils::transform::trt_generate_latents(std::vector<float> &latents,
for (size_t i = 0; i < total_size; ++i) {
latents[i] = dist(gen) * init_noise_sigma;
}
}

/**
 * @brief Keeps only the largest connected foreground region of an alpha matte.
 *
 * The matte is binarized at `threshold` (on a 0..255 scale), cleaned with a
 * 3x3 elliptical morphological opening, and labelled with 8-connectivity.
 * Every pixel that does not belong to the largest non-background component is
 * set to 0 in `alpha_pred`; pixels inside that component keep their value.
 *
 * @param alpha_pred Single-channel float matte (CV_32FC1), modified in place.
 *                   An empty matte is left untouched.
 * @param threshold  Binarization threshold as a fraction of 255
 *                   (e.g. 0.05 -> 12).
 * @throws cv::Exception if `alpha_pred` is non-empty but not CV_32FC1.
 */
void trtcv::utils::remove_small_connected_area(cv::Mat &alpha_pred, float threshold) {
    if (alpha_pred.empty()) return; // nothing to clean
    // The matte is written through float accessors below; any other element
    // type would be reinterpreted, so reject it up front.
    CV_Assert(alpha_pred.type() == CV_32FC1);

    cv::Mat gray, binary;
    alpha_pred.convertTo(gray, CV_8UC1, 255.f);
    // 255 * 0.05 ~ 13 (truncated to 12). A signed int keeps the conversion
    // well-defined even if a caller passes a negative threshold.
    const int binary_threshold = static_cast<int>(255.f * threshold);
    // https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L209
    cv::threshold(gray, binary, binary_threshold, 255, cv::THRESH_BINARY);
    // morphologyEx with OPEN operation to remove noise first.
    const cv::Mat kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3), cv::Point(-1, -1));
    cv::morphologyEx(binary, binary, cv::MORPH_OPEN, kernel);

    // Label connected regions; label 0 is the background.
    cv::Mat labels, stats, centroids;
    const int num_labels = cv::connectedComponentsWithStats(binary, labels, stats, centroids, 8, CV_32S);
    if (num_labels <= 1) return; // only background found, leave the matte as is.

    // Find the largest foreground component (labels 1, 2, ...).
    int max_connected_id = 1;
    int max_connected_area = stats.at<int>(max_connected_id, cv::CC_STAT_AREA);
    for (int label = 2; label < num_labels; ++label)
    {
        const int area = stats.at<int>(label, cv::CC_STAT_AREA);
        if (area > max_connected_area)
        {
            max_connected_area = area;
            max_connected_id = label;
        }
    }

    // Zero every pixel outside the largest component (background included).
    alpha_pred.setTo(cv::Scalar(0), labels != max_connected_id);
}
1 change: 1 addition & 0 deletions lite/trt/core/trt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace trtcv

LITE_EXPORTS void trt_generate_latents(std::vector<float>& latents, int batch_size, int unet_channels, int latent_height, int latent_width, float init_noise_sigma);
}
// Zeroes, in place, every pixel of the float alpha matte `alpha_pred` that is
// outside its largest connected foreground region. `threshold` is the
// binarization cut-off expressed as a fraction of 255 (e.g. 0.05).
LITE_EXPORTS void remove_small_connected_area(cv::Mat &alpha_pred, float threshold);
}
}

Expand Down
146 changes: 146 additions & 0 deletions lite/trt/cv/trt_modnet.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
//
// Created by wangzijian on 10/28/24.
//

#include "trt_modnet.h"
using trtcv::TRTMODNet;

/**
 * @brief Turns a BGR image into the normalized float tensor source for MODNet.
 *
 * Resizes to 512x512, converts BGR -> RGB, converts to CV_32FC3 and applies
 * (x - mean_val) * scale_val to every channel. The result replaces
 * `input_mat`.
 *
 * @param input_mat In: BGR image. Out: 512x512 CV_32FC3 normalized RGB image.
 */
void TRTMODNet::preprocess(cv::Mat &input_mat) {
    // Each stage writes into a fresh Mat. A cv::Mat copy shares its pixel
    // buffer with the source, so in-place resize/cvtColor on an image that is
    // already 512x512 would overwrite the caller's original pixels.
    cv::Mat resized;
    cv::resize(input_mat, resized, cv::Size(512, 512));

    cv::Mat rgb;
    cv::cvtColor(resized, rgb, cv::COLOR_BGR2RGB);
    if (rgb.type() != CV_32FC3) rgb.convertTo(rgb, CV_32FC3);

    // Scalar::all is required: a bare float converts to cv::Scalar(v, 0, 0, 0),
    // which would subtract the mean from the first channel only.
    input_mat = (rgb - cv::Scalar::all(mean_val)) * scale_val;
}



/**
 * @brief Runs MODNet on one image and fills `content` with the matting result.
 *
 * Preprocesses the image, uploads it to the device, runs the TensorRT engine
 * on `stream`, copies the output back and hands it to generate_matting().
 * On any failure (empty input, CUDA error, failed enqueue) the function logs
 * to stderr where applicable and returns without touching `content`.
 *
 * @param mat                  Input BGR image; not modified.
 * @param content              Receives the alpha matte and, unless
 *                             `minimum_post_process` is set, fgr/merge mats.
 * @param remove_noise         Forwarded to generate_matting().
 * @param minimum_post_process Forwarded to generate_matting().
 */
void TRTMODNet::detect(const cv::Mat &mat, types::MattingContent &content, bool remove_noise, bool minimum_post_process) {
    if (mat.empty()) return;

    // preprocess() transforms its argument by reference, and copy-constructing
    // a cv::Mat shares the pixel buffer; a deep copy keeps the caller's const
    // image untouched.
    cv::Mat preprocessed_mat = mat.clone();
    preprocess(preprocessed_mat);

    const int batch_size = 1;
    const int channels = 3;
    const int input_h = preprocessed_mat.rows;
    const int input_w = preprocessed_mat.cols;
    const size_t tensor_elems = static_cast<size_t>(batch_size) * channels * input_h * input_w;
    const size_t input_size = tensor_elems * sizeof(float);
    // NOTE(review): the output buffer is sized like the 3-channel input, while
    // generate_matting() only reads the first 512 * 512 floats. Kept as is;
    // confirm against the engine's real output shape before shrinking it.
    const size_t output_size = tensor_elems * sizeof(float);

    // Release buffers left over from a previous call before re-allocating.
    for (auto &buffer : buffers) {
        if (buffer) {
            cudaFree(buffer);
            buffer = nullptr;
        }
    }
    if (cudaMalloc(&buffers[0], input_size) != cudaSuccess ||
        cudaMalloc(&buffers[1], output_size) != cudaSuccess ||
        !buffers[0] || !buffers[1]) {
        std::cerr << "Failed to allocate CUDA memory" << std::endl;
        return;
    }

    input_node_dims = {batch_size, channels, input_h, input_w};

    std::vector<float> input;
    trtcv::utils::transform::create_tensor(preprocessed_mat, input, input_node_dims, trtcv::utils::transform::CHW);

    // Upload the input tensor.
    if (cudaMemcpyAsync(buffers[0], input.data(), input_size,
                        cudaMemcpyHostToDevice, stream) != cudaSuccess) {
        std::cerr << "Failed to copy input to device." << std::endl;
        return;
    }

    nvinfer1::Dims MODNetDims;
    MODNetDims.nbDims = 4;
    MODNetDims.d[0] = batch_size;
    MODNetDims.d[1] = channels;
    MODNetDims.d[2] = input_h;
    MODNetDims.d[3] = input_w;

    // IO tensor 0 is bound as the input and tensor 1 as the output.
    auto input_tensor_name = trt_engine->getIOTensorName(0);
    auto output_tensor_name = trt_engine->getIOTensorName(1);
    trt_context->setTensorAddress(input_tensor_name, buffers[0]);
    trt_context->setTensorAddress(output_tensor_name, buffers[1]);
    trt_context->setInputShape(input_tensor_name, MODNetDims);

    // Inference.
    const bool status = trt_context->enqueueV3(stream);
    if (!status) {
        std::cerr << "Failed to infer by TensorRT." << std::endl;
        return;
    }

    // Download the result and wait for the stream: both the inference and the
    // copy are asynchronous, so the host must not read `output` before the
    // stream has drained.
    std::vector<float> output(tensor_elems);
    if (cudaMemcpyAsync(output.data(), buffers[1], output_size,
                        cudaMemcpyDeviceToHost, stream) != cudaSuccess ||
        cudaStreamSynchronize(stream) != cudaSuccess) {
        std::cerr << "Failed to copy TensorRT output to host." << std::endl;
        return;
    }

    // Post-process.
    generate_matting(output.data(), mat, content, remove_noise, minimum_post_process);
}

/**
 * @brief Builds the matting result from the raw network output.
 *
 * Interprets `trt_outputs` as a 512x512 single-channel float alpha matte,
 * optionally removes small connected regions, resizes the matte to the size
 * of `mat` and stores it in `content.pha_mat`. Unless `minimum_post_process`
 * is set, it also fills `content.fgr_mat` and `content.merge_mat`.
 * Sets `content.flag` to true on success.
 *
 * @param trt_outputs          Host pointer to at least 512 * 512 floats; may be
 *                             modified in place by the noise removal.
 * @param mat                  Original input image (source of size and colour).
 * @param content              Output container.
 * @param remove_noise         Keep only the largest connected alpha region.
 * @param minimum_post_process Skip the fgr / merge computation.
 */
void TRTMODNet::generate_matting(float *trt_outputs, const cv::Mat &mat, types::MattingContent &content,
                                 bool remove_noise, bool minimum_post_process) {
    // Nothing to build without network output or a reference image.
    if (trt_outputs == nullptr || mat.empty()) return;

    const unsigned int h = mat.rows;
    const unsigned int w = mat.cols;

    const unsigned int out_h = 512;
    const unsigned int out_w = 512;

    // Wraps the caller's buffer without copying.
    cv::Mat alpha_pred(out_h, out_w, CV_32FC1, trt_outputs);
    // post process
    if (remove_noise) trtcv::utils::remove_small_connected_area(alpha_pred, 0.05f);
    // resize alpha
    if (out_h != h || out_w != w)
        // already allocated a new continuous memory after resize.
        cv::resize(alpha_pred, alpha_pred, cv::Size(w, h));
    // need clone to allocate a new continuous memory if not performed resize.
    // The memory elements point to will release after return.
    else alpha_pred = alpha_pred.clone();

    cv::Mat pmat = alpha_pred; // ref
    content.pha_mat = pmat; // auto handle the memory inside ocv with smart ref.

    if (!minimum_post_process)
    {
        // MODNet only predict Alpha, no fgr. So,
        // the fake fgr and merge mat may not need,
        // let the fgr mat and merge mat empty to
        // Speed up the post processes.
        cv::Mat mat_copy;
        mat.convertTo(mat_copy, CV_32FC3);
        // merge mat and fgr mat may not need
        std::vector<cv::Mat> mat_channels;
        cv::split(mat_copy, mat_channels);
        cv::Mat bmat = mat_channels.at(0);
        cv::Mat gmat = mat_channels.at(1);
        cv::Mat rmat = mat_channels.at(2); // ref only, zero-copy.
        // Fake foreground: each colour channel weighted by alpha.
        bmat = bmat.mul(pmat);
        gmat = gmat.mul(pmat);
        rmat = rmat.mul(pmat);
        // Merge: foreground over a fixed (153, 255, 120) BGR background.
        // NOTE(review): bmat/gmat/rmat are already alpha-weighted here, so the
        // foreground term below is weighted by alpha twice. Left unchanged to
        // preserve existing output; confirm whether this is intended.
        cv::Mat rest = 1.f - pmat;
        cv::Mat mbmat = bmat.mul(pmat) + rest * 153.f;
        cv::Mat mgmat = gmat.mul(pmat) + rest * 255.f;
        cv::Mat mrmat = rmat.mul(pmat) + rest * 120.f;
        std::vector<cv::Mat> fgr_channel_mats, merge_channel_mats;
        fgr_channel_mats.push_back(bmat);
        fgr_channel_mats.push_back(gmat);
        fgr_channel_mats.push_back(rmat);
        merge_channel_mats.push_back(mbmat);
        merge_channel_mats.push_back(mgmat);
        merge_channel_mats.push_back(mrmat);

        cv::merge(fgr_channel_mats, content.fgr_mat);
        cv::merge(merge_channel_mats, content.merge_mat);

        content.fgr_mat.convertTo(content.fgr_mat, CV_8UC3);
        content.merge_mat.convertTo(content.merge_mat, CV_8UC3);
    }

    content.flag = true;
}
34 changes: 34 additions & 0 deletions lite/trt/cv/trt_modnet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//
// Created by wangzijian on 10/28/24.
//

#ifndef LITE_AI_TOOLKIT_TRT_MODNET_H
#define LITE_AI_TOOLKIT_TRT_MODNET_H

#include "lite/trt/core/trt_core.h"
#include "lite/trt/core/trt_utils.h"

namespace trtcv
{
    /// MODNet matting model running on a TensorRT engine.
    class LITE_EXPORTS TRTMODNet : public BasicTRTHandler
    {
    public:
        /// Forwards the engine path and thread count to BasicTRTHandler.
        explicit TRTMODNet(const std::string &_trt_model_path, unsigned int _num_threads = 1)
            : BasicTRTHandler(_trt_model_path, _num_threads)
        {
        }

        /// Runs matting on `mat` and stores the result in `content`.
        /// `remove_noise` and `minimum_post_process` are passed on to
        /// generate_matting().
        void detect(const cv::Mat &mat, types::MattingContent &content,
                    bool remove_noise = false,
                    bool minimum_post_process = false);

    private:
        // Normalization constants used by preprocess(): (x - mean_val) * scale_val.
        static constexpr float mean_val = 127.5f; // RGB
        static constexpr float scale_val = 1.f / 127.5f;

        /// Resizes, colour-converts and normalizes `input_mat` in place.
        void preprocess(cv::Mat &input_mat);

        /// Converts the raw engine output into the fields of `content`.
        void generate_matting(float *trt_outputs,
                              const cv::Mat &mat, types::MattingContent &content,
                              bool remove_noise = false,
                              bool minimum_post_process = false);
    };
}



#endif //LITE_AI_TOOLKIT_TRT_MODNET_H

0 comments on commit 4252d27

Please sign in to comment.