Skip to content

Commit

Permalink
[Enhancement] Support two-stage rotated detector TensorRT. (open-mmla…
Browse files Browse the repository at this point in the history
…b#530)

* upload

* add fake_multiclass_nms_rotated

* delete unused code

* align with pytorch

* Update delta_midpointoffset_rbbox_coder.py

* add trt rotated roi align

* add index feature in nms

* not good

* fix index

* add ut

* add benchmark

* move to csrc/mmdeploy

* update unit test

Co-authored-by: zytx121 <[email protected]>
  • Loading branch information
2 people authored and grimoire committed Jun 25, 2022
1 parent f4a3251 commit 9c18aee
Show file tree
Hide file tree
Showing 27 changed files with 906 additions and 93 deletions.
43 changes: 23 additions & 20 deletions csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@ static const char* NMS_PLUGIN_VERSION{"1"};
static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"};
} // namespace

TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params)
: TRTPluginBase(name), param(params) {}
TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params, bool returnIndex)
: TRTPluginBase(name), param(params), mReturnIndex(returnIndex) {}

TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data, size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &param);
deserialize_value(&data, &length, &boxesSize);
deserialize_value(&data, &length, &scoresSize);
deserialize_value(&data, &length, &numPriors);
deserialize_value(&data, &length, &mClipBoxes);
deserialize_value(&data, &length, &mReturnIndex);
}

int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; }
int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT {
int num = mReturnIndex ? 3 : 2;
return num;
}

nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
Expand All @@ -51,6 +52,8 @@ nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
case 1:
ret.nbDims = 2;
break;
case 2:
ret.nbDims = 2;
default:
break;
}
Expand Down Expand Up @@ -81,6 +84,7 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,

void* nmsedDets = outputs[0];
void* nmsedLabels = outputs[1];
void* nmsedIndex = mReturnIndex ? outputs[2] : nullptr;

size_t batch_size = inputDesc[0].dims.d[0];
size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3];
Expand All @@ -94,24 +98,22 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
pluginStatus_t status = nmsInference(
stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId,
num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold,
DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, workSpace,
param.isNormalized, false, mClipBoxes, rotated);
DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nmsedIndex,
workSpace, param.isNormalized, false, mClipBoxes, rotated);
ASSERT(status == STATUS_SUCCESS);

return 0;
}

size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT {
// NMSParameters, boxesSize,scoresSize,numPriors
return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
// NMSParameters
return sizeof(NMSParameters) + sizeof(mClipBoxes) + sizeof(mReturnIndex);
}

void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, param);
serialize_value(&buffer, boxesSize);
serialize_value(&buffer, scoresSize);
serialize_value(&buffer, numPriors);
serialize_value(&buffer, mClipBoxes);
serialize_value(&buffer, mReturnIndex);
}

void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
Expand All @@ -122,7 +124,7 @@ void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inp

bool TRTBatchedNMS::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc,
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
if (pos == 3) {
if (pos == 3 || pos == 4) {
return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
Expand All @@ -135,10 +137,7 @@ const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGI
const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT { return NMS_PLUGIN_VERSION; }

IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT {
auto* plugin = new TRTBatchedNMS(mLayerName, param);
plugin->boxesSize = boxesSize;
plugin->scoresSize = scoresSize;
plugin->numPriors = numPriors;
auto* plugin = new TRTBatchedNMS(mLayerName, param, mReturnIndex);
plugin->setPluginNamespace(mNamespace.c_str());
plugin->setClipParam(mClipBoxes);
return plugin;
Expand All @@ -147,7 +146,7 @@ IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT {
nvinfer1::DataType TRTBatchedNMS::getOutputDataType(int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT {
ASSERT(index >= 0 && index < this->getNbOutputs());
if (index == 1) {
if (index == 1 || index == 2) {
return nvinfer1::DataType::kINT32;
}
return inputTypes[0];
Expand All @@ -167,6 +166,7 @@ TRTBatchedNMSCreator::TRTBatchedNMSCreator() {
PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("return_index", nullptr, PluginFieldType::kINT32, 1));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
Expand All @@ -182,6 +182,7 @@ IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name,
const PluginFieldCollection* fc) TRT_NOEXCEPT {
const PluginField* fields = fc->fields;
bool clipBoxes = true;
bool returnIndex = false;
nvinfer1::plugin::NMSParameters params{};

for (int i = 0; i < fc->nbFields; ++i) {
Expand All @@ -208,10 +209,12 @@ IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name,
params.isNormalized = *(static_cast<const bool*>(fields[i].data));
} else if (!strcmp(attrName, "clip_boxes")) {
clipBoxes = *(static_cast<const bool*>(fields[i].data));
} else if (!strcmp(attrName, "return_index")) {
returnIndex = *(static_cast<const bool*>(fields[i].data));
}
}

TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params);
TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params, returnIndex);
plugin->setClipParam(clipBoxes);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
#include "NvInferPluginUtils.h"
#include "trt_plugin_base.hpp"
namespace mmdeploy {

enum NMSReturnType { RETURN_DETS = 1, RETURN_INDEX = 1 << 1 };

class TRTBatchedNMS : public TRTPluginBase {
public:
TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param);
TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param, bool returnIndex);

TRTBatchedNMS(const std::string& name, const void* data, size_t length);

Expand Down Expand Up @@ -55,10 +58,8 @@ class TRTBatchedNMS : public TRTPluginBase {

private:
nvinfer1::plugin::NMSParameters param{};
int boxesSize{};
int scoresSize{};
int numPriors{};
bool mClipBoxes{};
bool mReturnIndex{};
};

class TRTBatchedNMSCreator : public TRTPluginCreatorBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameter
TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &param);
deserialize_value(&data, &length, &boxesSize);
deserialize_value(&data, &length, &scoresSize);
deserialize_value(&data, &length, &numPriors);
deserialize_value(&data, &length, &mClipBoxes);
}

Expand Down Expand Up @@ -94,23 +91,20 @@ int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
pluginStatus_t status = nmsInference(
stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId,
num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold,
DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, workSpace,
param.isNormalized, false, mClipBoxes, rotated);
DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nullptr,
workSpace, param.isNormalized, false, mClipBoxes, rotated);
ASSERT(status == STATUS_SUCCESS);

return 0;
}

size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT {
// NMSParameters, boxesSize,scoresSize,numPriors
return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
// NMSParameters,
return sizeof(NMSParameters) + sizeof(bool);
}

void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, param);
serialize_value(&buffer, boxesSize);
serialize_value(&buffer, scoresSize);
serialize_value(&buffer, numPriors);
serialize_value(&buffer, mClipBoxes);
}

Expand Down Expand Up @@ -140,9 +134,6 @@ const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT {

IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT {
auto* plugin = new TRTBatchedRotatedNMS(mLayerName, param);
plugin->boxesSize = boxesSize;
plugin->scoresSize = scoresSize;
plugin->numPriors = numPriors;
plugin->setPluginNamespace(mNamespace.c_str());
plugin->setClipParam(mClipBoxes);
return plugin;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ class TRTBatchedRotatedNMS : public TRTPluginBase {

private:
nvinfer1::plugin::NMSParameters param{};
int boxesSize{};
int scoresSize{};
int numPriors{};
bool mClipBoxes{};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
const float scoreThreshold, const float iouThreshold,
const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
const void* confData, void* nmsedDets, void* nmsedLabels,
void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes,
bool rotated = false);
void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid,
bool clipBoxes, bool rotated = false);

#endif
3 changes: 2 additions & 1 deletion csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation, int num
int numPredsPerClass, int numClasses, int topK, int keepTopK,
DataType DT_BBOX, DataType DT_SCORE, const void* indices,
const void* scores, const void* bboxData, void* nmsedDets,
void* nmsedLabels, bool clipBoxes = true, bool rotated = false);
void* nmsedLabels, void* nmsedIndex = nullptr,
bool clipBoxes = true, bool rotated = false);

size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses,
int numPredsPerClass, int topK, DataType DT_BBOX,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num
const int GS = num_classes;
const int t_size = (top_k + BS - 1) / BS;

ASSERT(t_size <= 10);
kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized,
(T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream, const int num, const
const int GS = num_classes;
const int t_size = (top_k + BS - 1) / BS;

ASSERT(t_size <= 10);
kernel[t_size - 1]<<<GS, BS, BS * t_size * sizeof(bool), stream>>>(
num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized,
(T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
const float scoreThreshold, const float iouThreshold,
const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
const void* confData, void* nmsedDets, void* nmsedLabels,
void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes,
bool rotated) {
void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid,
bool clipBoxes, bool rotated) {
const int topKVal = topK < 0 ? numPredsPerClass : topK;
const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
// locCount = batch_size * number_boxes_per_sample * 4
Expand Down Expand Up @@ -117,7 +117,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch
// Gather data from the sorted bounding boxes after NMS
status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topKVal,
keepTopKVal, DataType::kFLOAT, DataType::kFLOAT, indices, scores,
bboxData, nmsedDets, nmsedLabels, clipBoxes, rotated);
bboxData, nmsedDets, nmsedLabels, nmsedIndex, clipBoxes, rotated);

ASSERT_FAILURE(status == STATUS_SUCCESS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ __launch_bounds__(nthds_per_cta) __global__
const int numPredsPerClass, const int numClasses, const int topK,
const int keepTopK, const int *indices, const T_SCORE *scores,
const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels,
bool clipBoxes) {
int *nmsedIndex, bool clipBoxes) {
if (keepTopK > topK) return;
for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK;
i += gridDim.x * nthds_per_cta) {
Expand All @@ -23,6 +23,9 @@ __launch_bounds__(nthds_per_cta) __global__
const T_SCORE score = scores[offset + detId];
if (index == -1) {
nmsedLabels[i] = -1;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = -1;
}
if (rotated) {
nmsedDets[i * 6] = 0;
nmsedDets[i * 6 + 1] = 0;
Expand All @@ -46,6 +49,9 @@ __launch_bounds__(nthds_per_cta) __global__
: index % (numClasses * numPredsPerClass)) +
bboxOffset) *
5;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = bboxId / 5;
}
// clipped bbox xmin
nmsedDets[i * 6] =
clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId];
Expand All @@ -67,6 +73,9 @@ __launch_bounds__(nthds_per_cta) __global__
: index % (numClasses * numPredsPerClass)) +
bboxOffset) *
4;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = bboxId / 4;
}
// clipped bbox xmin
nmsedDets[i * 5] =
clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId];
Expand All @@ -90,12 +99,14 @@ pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocatio
const int numImages, const int numPredsPerClass,
const int numClasses, const int topK, const int keepTopK,
const void *indices, const void *scores, const void *bboxData,
void *nmsedDets, void *nmsedLabels, bool clipBoxes) {
void *nmsedDets, void *nmsedLabels, void *nmsedIndex,
bool clipBoxes) {
const int BS = 32;
const int GS = 32;
gatherNMSOutputs_kernel<T_BBOX, T_SCORE, rotated, BS><<<GS, BS, 0, stream>>>(
shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK, (int *)indices,
(T_SCORE *)scores, (T_BBOX *)bboxData, (T_BBOX *)nmsedDets, (int *)nmsedLabels, clipBoxes);
(T_SCORE *)scores, (T_BBOX *)bboxData, (T_BBOX *)nmsedDets, (int *)nmsedLabels,
(int *)nmsedIndex, clipBoxes);

CSC(cudaGetLastError(), STATUS_FAILURE);
return STATUS_SUCCESS;
Expand All @@ -104,7 +115,7 @@ pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocatio
// gatherNMSOutputs LAUNCH CONFIG {{{
typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int,
const int, const int, const void *, const void *, const void *,
void *, void *, bool);
void *, void *, void *, bool);
struct nmsOutLaunchConfig {
DataType t_bbox;
DataType t_score;
Expand Down Expand Up @@ -138,14 +149,15 @@ pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, c
const int numPredsPerClass, const int numClasses, const int topK,
const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE,
const void *indices, const void *scores, const void *bboxData,
void *nmsedDets, void *nmsedLabels, bool clipBoxes, bool rotated) {
void *nmsedDets, void *nmsedLabels, void *nmsedIndex,
bool clipBoxes, bool rotated) {
nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE, rotated);
for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) {
if (lc == nmsOutFuncVec[i]) {
DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i);
return nmsOutFuncVec[i].function(stream, shareLocation, numImages, numPredsPerClass,
numClasses, topK, keepTopK, indices, scores, bboxData,
nmsedDets, nmsedLabels, clipBoxes);
nmsedDets, nmsedLabels, nmsedIndex, clipBoxes);
}
}
return STATUS_BAD_PARAM;
Expand Down
Loading

0 comments on commit 9c18aee

Please sign in to comment.