diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp index b8150b992a..431f2dd63b 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp @@ -18,19 +18,20 @@ static const char* NMS_PLUGIN_VERSION{"1"}; static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"}; } // namespace -TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params) - : TRTPluginBase(name), param(params) {} +TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params, bool returnIndex) + : TRTPluginBase(name), param(params), mReturnIndex(returnIndex) {} TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data, size_t length) : TRTPluginBase(name) { deserialize_value(&data, &length, &param); - deserialize_value(&data, &length, &boxesSize); - deserialize_value(&data, &length, &scoresSize); - deserialize_value(&data, &length, &numPriors); deserialize_value(&data, &length, &mClipBoxes); + deserialize_value(&data, &length, &mReturnIndex); } -int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; } +int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { + int num = mReturnIndex ? 3 : 2; + return num; +} nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions( int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, @@ -51,6 +52,8 @@ nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions( case 1: ret.nbDims = 2; break; + case 2: + ret.nbDims = 2; default: break; } @@ -81,6 +84,7 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, void* nmsedDets = outputs[0]; void* nmsedLabels = outputs[1]; + void* nmsedIndex = mReturnIndex ? 
outputs[2] : nullptr; size_t batch_size = inputDesc[0].dims.d[0]; size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3]; @@ -94,24 +98,22 @@ int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, pluginStatus_t status = nmsInference( stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId, num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold, - DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, workSpace, - param.isNormalized, false, mClipBoxes, rotated); + DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nmsedIndex, + workSpace, param.isNormalized, false, mClipBoxes, rotated); ASSERT(status == STATUS_SUCCESS); return 0; } size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT { - // NMSParameters, boxesSize,scoresSize,numPriors - return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool); + // NMSParameters + return sizeof(NMSParameters) + sizeof(mClipBoxes) + sizeof(mReturnIndex); } void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT { serialize_value(&buffer, param); - serialize_value(&buffer, boxesSize); - serialize_value(&buffer, scoresSize); - serialize_value(&buffer, numPriors); serialize_value(&buffer, mClipBoxes); + serialize_value(&buffer, mReturnIndex); } void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, @@ -122,7 +124,7 @@ void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inp bool TRTBatchedNMS::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 3) { + if (pos == 3 || pos == 4) { return ioDesc[pos].type == nvinfer1::DataType::kINT32 && ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; } @@ -135,10 +137,7 @@ const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGI const 
char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT { return NMS_PLUGIN_VERSION; } IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT { - auto* plugin = new TRTBatchedNMS(mLayerName, param); - plugin->boxesSize = boxesSize; - plugin->scoresSize = scoresSize; - plugin->numPriors = numPriors; + auto* plugin = new TRTBatchedNMS(mLayerName, param, mReturnIndex); plugin->setPluginNamespace(mNamespace.c_str()); plugin->setClipParam(mClipBoxes); return plugin; @@ -147,7 +146,7 @@ IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT { nvinfer1::DataType TRTBatchedNMS::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT { ASSERT(index >= 0 && index < this->getNbOutputs()); - if (index == 1) { + if (index == 1 || index == 2) { return nvinfer1::DataType::kINT32; } return inputTypes[0]; @@ -167,6 +166,7 @@ TRTBatchedNMSCreator::TRTBatchedNMSCreator() { PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("return_index", nullptr, PluginFieldType::kINT32, 1)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); @@ -182,6 +182,7 @@ IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT { const PluginField* fields = fc->fields; bool clipBoxes = true; + bool returnIndex = false; nvinfer1::plugin::NMSParameters params{}; for (int i = 0; i < fc->nbFields; ++i) { @@ -208,10 +209,12 @@ IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name, params.isNormalized = *(static_cast(fields[i].data)); } else if (!strcmp(attrName, "clip_boxes")) { clipBoxes = *(static_cast(fields[i].data)); + } else if (!strcmp(attrName, "return_index")) { + returnIndex = 
*(static_cast<const bool *>(fields[i].data)); } } - TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params); + TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params, returnIndex); plugin->setClipParam(clipBoxes); plugin->setPluginNamespace(mNamespace.c_str()); return plugin; diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp index f37805213c..d1e5d643db 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp @@ -9,9 +9,12 @@ #include "NvInferPluginUtils.h" #include "trt_plugin_base.hpp" namespace mmdeploy { + +enum NMSReturnType { RETURN_DETS = 1, RETURN_INDEX = 1 << 1 }; + class TRTBatchedNMS : public TRTPluginBase { public: - TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param); + TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param, bool returnIndex); TRTBatchedNMS(const std::string& name, const void* data, size_t length); @@ -55,10 +58,8 @@ class TRTBatchedNMS : public TRTPluginBase { private: nvinfer1::plugin::NMSParameters param{}; - int boxesSize{}; - int scoresSize{}; - int numPriors{}; bool mClipBoxes{}; + bool mReturnIndex{}; }; class TRTBatchedNMSCreator : public TRTPluginCreatorBase { diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp index d478ee797d..9d977bc937 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp @@ -23,9 +23,6 @@ TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameter TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length) : TRTPluginBase(name) { deserialize_value(&data, &length, 
&param); - deserialize_value(&data, &length, &boxesSize); - deserialize_value(&data, &length, &scoresSize); - deserialize_value(&data, &length, &numPriors); deserialize_value(&data, &length, &mClipBoxes); } @@ -94,23 +91,20 @@ int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, pluginStatus_t status = nmsInference( stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId, num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold, - DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, workSpace, - param.isNormalized, false, mClipBoxes, rotated); + DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nullptr, + workSpace, param.isNormalized, false, mClipBoxes, rotated); ASSERT(status == STATUS_SUCCESS); return 0; } size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT { - // NMSParameters, boxesSize,scoresSize,numPriors - return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool); + // NMSParameters, + return sizeof(NMSParameters) + sizeof(bool); } void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT { serialize_value(&buffer, param); - serialize_value(&buffer, boxesSize); - serialize_value(&buffer, scoresSize); - serialize_value(&buffer, numPriors); serialize_value(&buffer, mClipBoxes); } @@ -140,9 +134,6 @@ const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT { IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT { auto* plugin = new TRTBatchedRotatedNMS(mLayerName, param); - plugin->boxesSize = boxesSize; - plugin->scoresSize = scoresSize; - plugin->numPriors = numPriors; plugin->setPluginNamespace(mNamespace.c_str()); plugin->setClipParam(mClipBoxes); return plugin; diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp index 
9e7de526a3..66479eb7e7 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp @@ -54,9 +54,6 @@ class TRTBatchedRotatedNMS : public TRTPluginBase { private: nvinfer1::plugin::NMSParameters param{}; - int boxesSize{}; - int scoresSize{}; - int numPriors{}; bool mClipBoxes{}; }; diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp index b4929bb675..22cffa0605 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp @@ -13,7 +13,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch const float scoreThreshold, const float iouThreshold, const DataType DT_BBOX, const void* locData, const DataType DT_SCORE, const void* confData, void* nmsedDets, void* nmsedLabels, - void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes, - bool rotated = false); + void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid, + bool clipBoxes, bool rotated = false); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h index 1b9561b19c..1b50fa4e9f 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h @@ -76,7 +76,8 @@ pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation, int num int numPredsPerClass, int numClasses, int topK, int keepTopK, DataType DT_BBOX, DataType DT_SCORE, const void* indices, const void* scores, const void* bboxData, void* nmsedDets, - void* nmsedLabels, bool clipBoxes = true, bool rotated = false); + void* nmsedLabels, void* nmsedIndex = nullptr, + bool clipBoxes = true, bool rotated = false); size_t 
detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, int numPredsPerClass, int topK, DataType DT_BBOX, diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu index 0a1e3e283a..d048a36eff 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu @@ -200,6 +200,7 @@ pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num const int GS = num_classes; const int t_size = (top_k + BS - 1) / BS; + ASSERT(t_size <= 10); kernel[t_size - 1]<<>>( num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array, diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu index a5e102fb55..8d3858deae 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu @@ -430,6 +430,7 @@ pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream, const int num, const const int GS = num_classes; const int t_size = (top_k + BS - 1) / BS; + ASSERT(t_size <= 10); kernel[t_size - 1]<<>>( num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array, diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp index 6be5293a3d..71cb7a8592 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp @@ -10,8 +10,8 @@ pluginStatus_t 
nmsInference(cudaStream_t stream, const int N, const int perBatch const float scoreThreshold, const float iouThreshold, const DataType DT_BBOX, const void* locData, const DataType DT_SCORE, const void* confData, void* nmsedDets, void* nmsedLabels, - void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes, - bool rotated) { + void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid, + bool clipBoxes, bool rotated) { const int topKVal = topK < 0 ? numPredsPerClass : topK; const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK; // locCount = batch_size * number_boxes_per_sample * 4 @@ -117,7 +117,7 @@ pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatch // Gather data from the sorted bounding boxes after NMS status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topKVal, keepTopKVal, DataType::kFLOAT, DataType::kFLOAT, indices, scores, - bboxData, nmsedDets, nmsedLabels, clipBoxes, rotated); + bboxData, nmsedDets, nmsedLabels, nmsedIndex, clipBoxes, rotated); ASSERT_FAILURE(status == STATUS_SUCCESS); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu index c86ccab5c7..8a0ec7bac8 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu @@ -12,7 +12,7 @@ __launch_bounds__(nthds_per_cta) __global__ const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const int *indices, const T_SCORE *scores, const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels, - bool clipBoxes) { + int *nmsedIndex, bool clipBoxes) { if (keepTopK > topK) return; for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK; i += gridDim.x * nthds_per_cta) { @@ -23,6 +23,9 @@ __launch_bounds__(nthds_per_cta) __global__ const T_SCORE score = 
scores[offset + detId]; if (index == -1) { nmsedLabels[i] = -1; + if (nmsedIndex != nullptr) { + nmsedIndex[i] = -1; + } if (rotated) { nmsedDets[i * 6] = 0; nmsedDets[i * 6 + 1] = 0; @@ -46,6 +49,9 @@ __launch_bounds__(nthds_per_cta) __global__ : index % (numClasses * numPredsPerClass)) + bboxOffset) * 5; + if (nmsedIndex != nullptr) { + nmsedIndex[i] = bboxId / 5; + } // clipped bbox xmin nmsedDets[i * 6] = clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; @@ -67,6 +73,9 @@ __launch_bounds__(nthds_per_cta) __global__ : index % (numClasses * numPredsPerClass)) + bboxOffset) * 4; + if (nmsedIndex != nullptr) { + nmsedIndex[i] = bboxId / 4; + } // clipped bbox xmin nmsedDets[i * 5] = clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; @@ -90,12 +99,14 @@ pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocatio const int numImages, const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const void *indices, const void *scores, const void *bboxData, - void *nmsedDets, void *nmsedLabels, bool clipBoxes) { + void *nmsedDets, void *nmsedLabels, void *nmsedIndex, + bool clipBoxes) { const int BS = 32; const int GS = 32; gatherNMSOutputs_kernel<<>>( shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK, (int *)indices, - (T_SCORE *)scores, (T_BBOX *)bboxData, (T_BBOX *)nmsedDets, (int *)nmsedLabels, clipBoxes); + (T_SCORE *)scores, (T_BBOX *)bboxData, (T_BBOX *)nmsedDets, (int *)nmsedLabels, + (int *)nmsedIndex, clipBoxes); CSC(cudaGetLastError(), STATUS_FAILURE); return STATUS_SUCCESS; @@ -104,7 +115,7 @@ pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocatio // gatherNMSOutputs LAUNCH CONFIG {{{ typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int, const int, const int, const void *, const void *, const void *, - void *, void *, bool); + void *, void *, void *, bool); 
struct nmsOutLaunchConfig { DataType t_bbox; DataType t_score; @@ -138,14 +149,15 @@ pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, c const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE, const void *indices, const void *scores, const void *bboxData, - void *nmsedDets, void *nmsedLabels, bool clipBoxes, bool rotated) { + void *nmsedDets, void *nmsedLabels, void *nmsedIndex, + bool clipBoxes, bool rotated) { nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE, rotated); for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) { if (lc == nmsOutFuncVec[i]) { DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i); return nmsOutFuncVec[i].function(stream, shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK, indices, scores, bboxData, - nmsedDets, nmsedLabels, clipBoxes); + nmsedDets, nmsedLabels, nmsedIndex, clipBoxes); } } return STATUS_BAD_PARAM; diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp new file mode 100644 index 0000000000..6637603128 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp @@ -0,0 +1,228 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include "trt_multi_level_rotated_roi_align.hpp" + +#include + +#include + +#include "trt_multi_level_rotated_roi_align_kernel.hpp" +#include "trt_plugin_helper.hpp" +#include "trt_serialize.hpp" +namespace mmdeploy { +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"MMCVMultiLevelRotatedRoiAlign"}; +} // namespace + +TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign( + const std::string &name, int alignedHeight, int alignedWidth, int clockwise, int sampleNum, + const std::vector &featmapStrides, float roiScaleFactor, int finestScale, bool aligned) + : TRTPluginBase(name), + mAlignedHeight(alignedHeight), + mAlignedWidth(alignedWidth), + mClockwise(clockwise), + mSampleNum(sampleNum), + mFeatmapStrides(featmapStrides), + mRoiScaleFactor(roiScaleFactor), + mFinestScale(finestScale), + mAligned(aligned) {} + +TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, + size_t length) + : TRTPluginBase(name) { + deserialize_value(&data, &length, &mAlignedHeight); + deserialize_value(&data, &length, &mAlignedWidth); + deserialize_value(&data, &length, &mClockwise); + deserialize_value(&data, &length, &mSampleNum); + deserialize_value(&data, &length, &mRoiScaleFactor); + deserialize_value(&data, &length, &mFinestScale); + deserialize_value(&data, &length, &mAligned); + deserialize_value(&data, &length, &mFeatmapStrides); +} + +nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRotatedRoiAlign::clone() const TRT_NOEXCEPT { + TRTMultiLevelRotatedRoiAlign *plugin = new TRTMultiLevelRotatedRoiAlign( + mLayerName, mAlignedHeight, mAlignedWidth, mClockwise, mSampleNum, mFeatmapStrides, + mRoiScaleFactor, mFinestScale, mAligned); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs TRTMultiLevelRotatedRoiAlign::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { 
+ // warning, nbInputs should equal to mFeatmapStrides.size() + 1 + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[1].d[1]; + ret.d[2] = exprBuilder.constant(mAlignedHeight); + ret.d[3] = exprBuilder.constant(mAlignedWidth); + + return ret; +} + +bool TRTMultiLevelRotatedRoiAlign::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; +} + +void TRTMultiLevelRotatedRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, + int nbOutputs) TRT_NOEXCEPT { + // Validate input arguments + ASSERT(nbOutputs == 1); + ASSERT(nbInputs >= 1); + mFeatmapStrides = + std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + nbInputs - 1); +} + +size_t TRTMultiLevelRotatedRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT { + return 0; +} + +int TRTMultiLevelRotatedRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, + void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { + int num_rois = inputDesc[0].dims.d[0]; + int batch_size = inputDesc[1].dims.d[0]; + int channels = inputDesc[1].dims.d[1]; + + const int kMaxFeatMap = 10; + int heights[kMaxFeatMap]; + int widths[kMaxFeatMap]; + float strides[kMaxFeatMap]; + + int num_feats = mFeatmapStrides.size(); + for (int i = 0; i < num_feats; ++i) { + heights[i] = inputDesc[i + 1].dims.d[2]; + widths[i] = inputDesc[i + 1].dims.d[3]; + strides[i] = mFeatmapStrides[i]; + } + + const void *rois = inputs[0]; + const void *const *feats = inputs + 1; + + multi_level_rotated_roi_align((float *)outputs[0], 
(const float *)rois, num_rois, feats, + num_feats, batch_size, channels, &heights[0], &widths[0], + &strides[0], mAlignedHeight, mAlignedWidth, mClockwise, + mSampleNum, mRoiScaleFactor, mFinestScale, mAligned, stream); + + return 0; +} + +nvinfer1::DataType TRTMultiLevelRotatedRoiAlign::getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { + return nvinfer1::DataType::kFLOAT; +} + +// IPluginV2 Methods +const char *TRTMultiLevelRotatedRoiAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } + +const char *TRTMultiLevelRotatedRoiAlign::getPluginVersion() const TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +int TRTMultiLevelRotatedRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } + +size_t TRTMultiLevelRotatedRoiAlign::getSerializationSize() const TRT_NOEXCEPT { + return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + + serialized_size(mAlignedWidth) + serialized_size(mClockwise) + + serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) + + serialized_size(mFinestScale) + serialized_size(mAligned); +} + +void TRTMultiLevelRotatedRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT { + serialize_value(&buffer, mAlignedHeight); + serialize_value(&buffer, mAlignedWidth); + serialize_value(&buffer, mClockwise); + serialize_value(&buffer, mSampleNum); + serialize_value(&buffer, mRoiScaleFactor); + serialize_value(&buffer, mFinestScale); + serialize_value(&buffer, mAligned); + serialize_value(&buffer, mFeatmapStrides); +} + +TRTMultiLevelRotatedRoiAlignCreator::TRTMultiLevelRotatedRoiAlignCreator() { + mPluginAttributes = std::vector( + {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), + nvinfer1::PluginField("clockwise"), nvinfer1::PluginField("sampling_ratio"), + nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), + nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); + mFC.nbFields = 
mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginName() const TRT_NOEXCEPT { + return PLUGIN_NAME; +} + +const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { + int alignedHeight = 7; + int alignedWidth = 7; + int clockwise = 0; + int sampleNum = 2; + std::vector featmapStrides; + float roiScaleFactor = -1; + int finestScale = 56; + bool aligned = false; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("output_height") == 0) { + alignedHeight = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("output_width") == 0) { + alignedWidth = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("clockwise") == 0) { + clockwise = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("sampling_ratio") == 0) { + sampleNum = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("roi_scale_factor") == 0) { + roiScaleFactor = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("finest_scale") == 0) { + finestScale = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("featmap_strides") == 0) { + int data_size = (fc->fields[i].length); + const float *data_start = static_cast(fc->fields[i].data); + featmapStrides = std::vector(data_start, data_start + data_size); + } else if (field_name.compare("aligned") == 0) { + int aligned_int = static_cast(fc->fields[i].data)[0]; + aligned = aligned_int != 0; + } + } + + ASSERT(featmapStrides.size() != 0); + + TRTMultiLevelRotatedRoiAlign *plugin = + new TRTMultiLevelRotatedRoiAlign(name, alignedHeight, alignedWidth, clockwise, sampleNum, 
+ featmapStrides, roiScaleFactor, finestScale, aligned); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::deserializePlugin( + const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { + auto plugin = new TRTMultiLevelRotatedRoiAlign(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRotatedRoiAlignCreator); +} // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp new file mode 100644 index 0000000000..cf0bab7584 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp @@ -0,0 +1,79 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP +#define TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP + +#include + +#include +#include +#include + +#include "trt_plugin_base.hpp" + +namespace mmdeploy { +class TRTMultiLevelRotatedRoiAlign : public TRTPluginBase { + public: + TRTMultiLevelRotatedRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, + int clockwise, int sampleNum, + const std::vector &featmapStrides, float roiScaleFactor = -1, + int finestScale = 56, bool aligned = false); + + TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, size_t length); + + TRTMultiLevelRotatedRoiAlign() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, + int nbInputs, nvinfer1::IExprBuilder &exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int 
nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char *getPluginType() const TRT_NOEXCEPT override; + const char *getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void *buffer) const TRT_NOEXCEPT override; + + private: + int mAlignedHeight; + int mAlignedWidth; + int mClockwise; + int mSampleNum; + std::vector mFeatmapStrides; + float mRoiScaleFactor; + int mFinestScale; + bool mAligned; +}; + +class TRTMultiLevelRotatedRoiAlignCreator : public TRTPluginCreatorBase { + public: + TRTMultiLevelRotatedRoiAlignCreator(); + + const char *getPluginName() const TRT_NOEXCEPT override; + + const char *getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, + size_t serialLength) TRT_NOEXCEPT override; +}; +} // namespace mmdeploy +#endif // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu 
b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu new file mode 100644 index 0000000000..1c6f292bae --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu @@ -0,0 +1,164 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include +#include + +#include +#include + +#include "common_cuda_helper.hpp" +#include "trt_multi_level_rotated_roi_align_kernel.hpp" +#include "trt_plugin_helper.hpp" + +const int kMAX_FEATMAP_SIZE = 10; +struct FeatData { + const void *data[kMAX_FEATMAP_SIZE]; + int batch_size; + int channels; + int h[kMAX_FEATMAP_SIZE]; + int w[kMAX_FEATMAP_SIZE]; + float spatial_scale[kMAX_FEATMAP_SIZE]; + int num_featmap; +}; + +template +__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data, + const int roi_batch_ind, scalar_t roi_center_w, + scalar_t roi_center_h, scalar_t roi_width, scalar_t roi_height, + scalar_t theta, const scalar_t spatial_scale, const int pw, + const int ph, const int c, const int sample_num, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width) { + // Force malformed ROIs to be 1x1 + + roi_width = max(roi_width, (scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + + const scalar_t bin_size_h = roi_height / scalar_t(pooled_height); + const scalar_t bin_size_w = roi_width / scalar_t(pooled_width); + + const scalar_t *offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + const int roi_bin_grid_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); + const int roi_bin_grid_w = (sample_num > 0) ? 
sample_num : ceil(roi_width / pooled_width); + + const scalar_t roi_start_h = -roi_height / scalar_t(2.0); + const scalar_t roi_start_w = -roi_width / scalar_t(2.0); + const scalar_t cosscalar_theta = cos(theta); + const scalar_t sinscalar_theta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + scalar_t output_val = 0.; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const scalar_t yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const scalar_t xx = + roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + // Rotate by theta (counterclockwise) around the center and translate + scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; + scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; + + scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); + output_val += val; + } + } + + return output_val / count; +} + +template +__global__ void rotated_roi_extractor_kernel(scalar_t *__restrict__ output, + const scalar_t *__restrict__ bottom_rois, + FeatData feat_data, const int clockwise, + const int sample_num, const float roi_scale_factor, + const int finest_scale, const int pooled_height, + const int pooled_width, int nThreads) { + CUDA_1D_KERNEL_LOOP(index, nThreads) { + const int channels = feat_data.channels; + int tmp_index = index; + const int pw = tmp_index % pooled_width; + tmp_index /= pooled_width; + const int ph = tmp_index % pooled_height; + tmp_index /= pooled_height; + const int c = tmp_index % channels; + const int n = tmp_index / channels; + + const scalar_t *offset_bottom_rois = bottom_rois + n * 6; + + scalar_t roi_offset_x0 = offset_bottom_rois[1]; + scalar_t roi_offset_y0 = offset_bottom_rois[2]; + 
scalar_t roi_offset_width = offset_bottom_rois[3]; + scalar_t roi_offset_height = offset_bottom_rois[4]; + scalar_t theta = offset_bottom_rois[5]; + + const scalar_t scale = sqrtf(roi_offset_width * roi_offset_height); + + const int target_lvls = + min(feat_data.num_featmap - 1, + max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); + + if (roi_scale_factor > 0.) { + roi_offset_width = roi_offset_width * roi_scale_factor; + roi_offset_height = roi_offset_height * roi_scale_factor; + } + + const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; + const int height = feat_data.h[target_lvls]; + const int width = feat_data.w[target_lvls]; + const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls]; + + const int roi_batch_ind = offset_bottom_rois[0]; + const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0; + const scalar_t roi_center_w = fma(roi_offset_x0, spatial_scale, offset); + const scalar_t roi_center_h = fma(roi_offset_y0, spatial_scale, offset); + const scalar_t roi_width = roi_offset_width * spatial_scale; + const scalar_t roi_height = roi_offset_height * spatial_scale; + + theta = clockwise > 0 ? 
-theta : theta; + + const scalar_t output_val = roi_align_single( + bottom_data, roi_batch_ind, roi_center_w, roi_center_h, roi_width, roi_height, theta, + spatial_scale, pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); + output[index] = output_val; + } +} + +template +void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois, const void *const *feats, + int num_feats, int n, int c, int *h, int *w, float *strides, + int aligned_height, int aligned_width, int clockwise, + int sample_num, float roi_scale_factor, int finest_scale, + bool aligned, cudaStream_t stream) { + FeatData feat_data; + feat_data.batch_size = n; + feat_data.channels = c; + feat_data.num_featmap = num_feats; + for (int i = 0; i < num_feats; ++i) { + feat_data.data[i] = feats[i]; + feat_data.h[i] = h[i]; + feat_data.w[i] = w[i]; + feat_data.spatial_scale[i] = 1. / float(strides[i]); + } + int nThreads = num_rois * c * aligned_height * aligned_width; + if (aligned) { + rotated_roi_extractor_kernel<<>>( + output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale, + aligned_height, aligned_width, nThreads); + } else { + rotated_roi_extractor_kernel<<>>( + output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale, + aligned_height, aligned_width, nThreads); + } +} + +template void multi_level_rotated_roi_align( + float *output, const float *rois, int num_rois, const void *const *feats, int num_feats, int n, + int c, int *h, int *w, float *strides, int aligned_height, int aligned_width, int clockwise, + int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp new file mode 100644 index 0000000000..fc3700df3b --- /dev/null +++ 
b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp @@ -0,0 +1,13 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP +#define TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP +#include + +template +void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois, const void *const *feats, + int num_feats, int n, int c, int *h, int *w, float *strides, + int aligned_height, int aligned_width, int clockwise, + int sample_num, float roi_scale_factor, int finest_scale, + bool aligned, cudaStream_t stream); + +#endif // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP diff --git a/docs/en/03-benchmark/benchmark.md b/docs/en/03-benchmark/benchmark.md index 14580c8c1c..b6fa4b3c11 100644 --- a/docs/en/03-benchmark/benchmark.md +++ b/docs/en/03-benchmark/benchmark.md @@ -1580,8 +1580,8 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../ mAP 0.756 0.756 - - - - + 0.758 + 0.730 - - diff --git a/docs/en/04-supported-codebases/mmrotate.md b/docs/en/04-supported-codebases/mmrotate.md index 22f15a2938..1b9e403ae1 100644 --- a/docs/en/04-supported-codebases/mmrotate.md +++ b/docs/en/04-supported-codebases/mmrotate.md @@ -11,7 +11,7 @@ Please refer to [official installation guide](https://mmrotate.readthedocs.io/en | Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config | | :--------------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: | | RotatedRetinaNet | RotatedDetection | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | -| Oriented RCNN | RotatedDetection | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | +| Oriented RCNN | RotatedDetection 
| Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | ### Example diff --git a/docs/zh_cn/03-benchmark/benchmark.md b/docs/zh_cn/03-benchmark/benchmark.md index f0d2c87681..674121deb5 100644 --- a/docs/zh_cn/03-benchmark/benchmark.md +++ b/docs/zh_cn/03-benchmark/benchmark.md @@ -1577,8 +1577,8 @@ GPU: ncnn, TensorRT, PPLNN mAP 0.756 0.756 - - - - + 0.758 + 0.730 - - diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md index 69f5b48caf..767fae8add 100644 --- a/docs/zh_cn/03-benchmark/supported_models.md +++ b/docs/zh_cn/03-benchmark/supported_models.md @@ -69,6 +69,7 @@ | PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) | | CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) | | RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) | +| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) | ## Note diff --git a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py index 2933ca8be3..22ef641430 100644 --- a/mmdeploy/codebase/mmrotate/core/bbox/__init__.py +++ b/mmdeploy/codebase/mmrotate/core/bbox/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .delta_midpointoffset_rbbox_coder import * # noqa: F401,F403 from .delta_xywha_rbbox_coder import * # noqa: F401,F403 +from .transforms import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/core/bbox/transforms.py b/mmdeploy/codebase/mmrotate/core/bbox/transforms.py new file mode 100644 index 0000000000..61bdefc5d3 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/core/bbox/transforms.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmrotate.core.bbox.transforms import norm_angle + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmrotate.core.bbox.transforms.poly2obb_le90', + backend='tensorrt') +def poly2obb_le90__tensorrt(ctx, polys: torch.Tensor) -> torch.Tensor: + """This is a rewrite for poly2obb to remove NonZero ops. + + Args: + ctx : context of the rewriter. + polys (torch.Tensor): input + + Returns: + torch.Tensor: output + """ + polys = torch.reshape(polys, [-1, 8]) + pt1, pt2, pt3, pt4 = polys[..., :8].chunk(4, 1) + edge1 = torch.sqrt( + torch.pow(pt1[..., 0] - pt2[..., 0], 2) + + torch.pow(pt1[..., 1] - pt2[..., 1], 2)) + edge2 = torch.sqrt( + torch.pow(pt2[..., 0] - pt3[..., 0], 2) + + torch.pow(pt2[..., 1] - pt3[..., 1], 2)) + angles1 = torch.atan2((pt2[..., 1] - pt1[..., 1]), + (pt2[..., 0] - pt1[..., 0])) + angles2 = torch.atan2((pt4[..., 1] - pt1[..., 1]), + (pt4[..., 0] - pt1[..., 0])) + angles = torch.where(edge1 > edge2, angles1, angles2) + angles = norm_angle(angles, 'le90') + x_ctr = (pt1[..., 0] + pt3[..., 0]) / 2.0 + y_ctr = (pt1[..., 1] + pt3[..., 1]) / 2.0 + edges = torch.stack([edge1, edge2], dim=1) + width, _ = torch.max(edges, 1) + height, _ = torch.min(edges, 1) + return torch.stack([x_ctr, y_ctr, width, height, angles], 1) diff --git a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py index 7c884937ac..336f49aa3c 100644 --- 
a/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py +++ b/mmdeploy/codebase/mmrotate/core/post_processing/bbox_nms.py @@ -5,7 +5,7 @@ import mmdeploy from mmdeploy.core import FUNCTION_REWRITER, mark -from mmdeploy.mmcv.ops import (ONNXNMSop, ONNXNMSRotatedOp, +from mmdeploy.mmcv.ops import (ONNXNMSop, ONNXNMSRotatedOp, TRTBatchedNMSop, TRTBatchedRotatedNMSop) @@ -127,14 +127,14 @@ def _multiclass_nms_rotated(boxes: Tensor, func_name='mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.' '_multiclass_nms_rotated', backend='tensorrt') -def multiclass_nms_rotated_static(ctx, - boxes: Tensor, - scores: Tensor, - max_output_boxes_per_class: int = 1000, - iou_threshold: float = 0.5, - score_threshold: float = 0.05, - pre_top_k: int = -1, - keep_top_k: int = -1): +def multiclass_nms_rotated__tensorrt(ctx, + boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): """Wrapper for `multiclass_nms` with TensorRT. Args: @@ -178,18 +178,14 @@ def multiclass_nms_rotated(*args, **kwargs): _multiclass_nms_rotated(*args, **kwargs) -@mark( - 'fake_multiclass_nms_rotated', - inputs=['boxes', 'scores'], - outputs=['dets', 'labels']) -def fake_multiclass_nms_rotated(boxes: Tensor, - scores: Tensor, - max_output_boxes_per_class: int = 1000, - iou_threshold: float = 0.5, - score_threshold: float = 0.0, - pre_top_k: int = -1, - keep_top_k: int = -1, - version: str = 'le90'): +def _fake_multiclass_nms_rotated(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.0, + pre_top_k: int = -1, + keep_top_k: int = -1, + version: str = 'le90'): """Fake NMSRotated for multi-class bboxes which use horizontal bboxes for NMS, but return the rotated bboxes result. 
@@ -220,3 +216,70 @@ def fake_multiclass_nms_rotated(boxes: Tensor, scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k) return dets, labels + + +@mark( + 'fake_multiclass_nms_rotated', + inputs=['boxes', 'scores'], + outputs=['dets', 'labels']) +def fake_multiclass_nms_rotated(*args, **kwargs): + """Wrapper function for `_fake_multiclass_nms_rotated`.""" + return mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.\ + _fake_multiclass_nms_rotated(*args, **kwargs) + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdeploy.codebase.mmrotate.core.post_processing.bbox_nms.' + '_fake_multiclass_nms_rotated', + backend='tensorrt') +def _fake_multiclass_nms_rotated__tensorrt( + ctx, + boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.0, + pre_top_k: int = -1, + keep_top_k: int = -1, + version: str = 'le90'): + """Wrapper for `multiclass_nms` with TensorRT. + + Args: + ctx (ContextCaller): The context with additional information. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 5]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + max_output_boxes_per_class (int): Maximum number of output + boxes per class of nms. Defaults to 1000. + iou_threshold (float): IOU threshold of nms. Defaults to 0.5. + score_threshold (float): score threshold of nms. + Defaults to 0.05. + pre_top_k (int): Number of top K boxes to keep before nms. + Defaults to -1. + keep_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + + Returns: + tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 6] + and `labels` of shape [N, num_det]. 
+ """ + batch_size = boxes.size(0) + device = boxes.device + hboxes = obb2xyxy(boxes, version) + hboxes = hboxes if hboxes.dim() == 4 else hboxes.unsqueeze(2) + keep_top_k = max_output_boxes_per_class if keep_top_k < 0 else min( + max_output_boxes_per_class, keep_top_k) + if pre_top_k > 512 * 10 or pre_top_k < 0: + pre_top_k = 512 * 10 + + dets, labels, index = TRTBatchedNMSop.apply(hboxes, scores, + int(scores.shape[-1]), + pre_top_k, keep_top_k, + iou_threshold, score_threshold, + -1, True) + dets = torch.cat([boxes, scores], dim=-1) + dets = torch.cat([dets, dets[:, :1, :] * 0], dim=1) + batch_inds = torch.arange(batch_size, device=device).view(-1, 1) + dets = dets[batch_inds, index, :] + + return dets, labels diff --git a/mmdeploy/codebase/mmrotate/models/__init__.py b/mmdeploy/codebase/mmrotate/models/__init__.py index 32a7d21e7a..5fe7e5a1c7 100644 --- a/mmdeploy/codebase/mmrotate/models/__init__.py +++ b/mmdeploy/codebase/mmrotate/models/__init__.py @@ -2,6 +2,7 @@ from .oriented_standard_roi_head import ( oriented_standard_roi_head__simple_test, oriented_standard_roi_head__simple_test_bboxes) +from .roi_extractors import rotated_single_roi_extractor__forward__tensorrt from .rotated_anchor_head import rotated_anchor_head__get_bbox from .rotated_bbox_head import rotated_bbox_head__get_bboxes from .rotated_rpn_head import rotated_rpn_head__get_bboxes @@ -13,5 +14,6 @@ 'rotated_anchor_head__get_bbox', 'rotated_rpn_head__get_bboxes', 'oriented_standard_roi_head__simple_test', 'oriented_standard_roi_head__simple_test_bboxes', - 'rotated_bbox_head__get_bboxes' + 'rotated_bbox_head__get_bboxes', + 'rotated_single_roi_extractor__forward__tensorrt' ] diff --git a/mmdeploy/codebase/mmrotate/models/roi_extractors.py b/mmdeploy/codebase/mmrotate/models/roi_extractors.py new file mode 100644 index 0000000000..f48e0dcf37 --- /dev/null +++ b/mmdeploy/codebase/mmrotate/models/roi_extractors.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmcv.ops import RoIAlignRotated +from torch.autograd import Function + +from mmdeploy.core.optimizers import mark +from mmdeploy.core.rewriters import FUNCTION_REWRITER + + +class MultiLevelRotatedRoiAlign(Function): + """Create MMCVMultiLevelRotatedRoiAlign op. + + This class is used to create a MultiLevelRotatedRoiAlign in ONNX for the + TensorRT backend. + """ + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def symbolic(g, *args): + """Symbolic function for creating onnx op.""" + aligned = args[-1] + featmap_strides = args[-2] + finest_scale = args[-3] + roi_scale_factor = args[-4] + sampling_ratio = args[-5] + clockwise = args[-6] + output_size = args[-7] + inputs = args[:len(featmap_strides)] + rois = args[len(featmap_strides)] + return g.op( + 'mmdeploy::MMCVMultiLevelRotatedRoiAlign', + rois, + *inputs, + output_height_i=output_size[1], + output_width_i=output_size[0], + clockwise_i=clockwise, + sampling_ratio_i=sampling_ratio, + roi_scale_factor_f=roi_scale_factor, + finest_scale_i=finest_scale, + featmap_strides_f=featmap_strides, + aligned_i=aligned) + + @staticmethod + def forward(g, *args): + """Run forward.""" + # aligned = args[-1] + featmap_strides = args[-2] + # finest_scale = args[-3] + # roi_scale_factor = args[-4] + # sampling_ratio = args[-5] + output_size = args[-7] + inputs = args[:len(featmap_strides)] + rois = args[len(featmap_strides)] + + num_proposals = rois.shape[0] + channel = inputs[0].shape[1] + + return rois.new_zeros( + (num_proposals, channel, output_size[1], output_size[0])) + + +@FUNCTION_REWRITER.register_rewriter( + 'mmrotate.models.roi_heads.roi_extractors.' 
+ 'rotate_single_level_roi_extractor.RotatedSingleRoIExtractor.forward', + backend='tensorrt') +@mark( + 'rotated_roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) +def rotated_single_roi_extractor__forward__tensorrt(ctx, + self, + feats, + rois, + roi_scale_factor=None): + """Rewrite `forward` of `RotatedSingleRoIExtractor` for TensorRT backend. + + This function uses MMCVMultiLevelRoiAlign op for TensorRT deployment. + """ + featmap_strides = self.featmap_strides + finest_scale = self.finest_scale + + for roi_layer in self.roi_layers: + assert isinstance(roi_layer, RoIAlignRotated + ), f'{type(roi_layer)} is not supported in TensorRT.' + + roi_layer = self.roi_layers[0] + out_size = roi_layer.output_size + sampling_ratio = roi_layer.sampling_ratio + clockwise = roi_layer.clockwise + aligned = roi_layer.aligned + if roi_scale_factor is None: + roi_scale_factor = 1.0 + + featmap_strides = [float(s) for s in featmap_strides] + return MultiLevelRotatedRoiAlign.apply(*feats, rois, out_size, clockwise, + sampling_ratio, roi_scale_factor, + finest_scale, featmap_strides, + aligned) diff --git a/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py b/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py index 389c672155..adef7b5b90 100644 --- a/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py +++ b/mmdeploy/codebase/mmrotate/models/rotated_rpn_head.py @@ -95,8 +95,7 @@ def rotated_rpn_head__get_bboxes(ctx, if not is_dynamic_flag: anchors = anchors.data - # anchors = anchors.expand_as(bbox_pred) - anchors = anchors.expand(batch_size, -1, anchors.size(-1)) + anchors = anchors.unsqueeze(0) # topk in tensorrt does not support shape 0: _, topk_inds = scores.squeeze(2).topk(pre_topk) - batch_inds = torch.arange( - batch_size, device=device).view(-1, 1).expand_as(topk_inds) - anchors = anchors[batch_inds, topk_inds, :] + batch_inds = torch.arange(batch_size, device=device).unsqueeze(-1) + prior_inds = topk_inds.new_zeros((1, 1)) + anchors = 
anchors[prior_inds, topk_inds, :] bbox_pred = bbox_pred[batch_inds, topk_inds, :] scores = scores[batch_inds, topk_inds, :] mlvl_valid_bboxes.append(bbox_pred) diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index 2071240d9e..1ab303bbaa 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -147,7 +147,8 @@ def forward(ctx, after_topk: int, iou_threshold: float, score_threshold: float, - background_label_id: int = -1): + background_label_id: int = -1, + return_index: bool = False): """Forward of batched nms. Args: @@ -175,10 +176,13 @@ def forward(ctx, batch_size, num_boxes, num_classes = scores.shape out_boxes = min(num_boxes, after_topk) - return torch.rand(batch_size, out_boxes, - 5).to(scores.device), torch.randint( - 0, num_classes, - (batch_size, out_boxes)).to(scores.device) + ret = (torch.rand(batch_size, out_boxes, 5).to(scores.device), + torch.randint(0, num_classes, + (batch_size, out_boxes)).to(scores.device)) + if return_index: + ret = ret + (torch.randint( + 0, out_boxes, (batch_size, out_boxes)).to(scores.device), ) + return ret @staticmethod def symbolic(g, @@ -189,7 +193,8 @@ def symbolic(g, after_topk: int, iou_threshold: float, score_threshold: float, - background_label_id: int = -1): + background_label_id: int = -1, + return_index: bool = False): """Symbolic function for mmdeploy::TRTBatchedNMS.""" return g.op( 'mmdeploy::TRTBatchedNMS', @@ -203,4 +208,5 @@ def symbolic(g, keep_topk_i=after_topk, is_normalized_i=False, clip_boxes_i=False, - outputs=2) + return_index_i=return_index, + outputs=3 if return_index else 2) diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py index 487cbb6277..e9c8d4936c 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py @@ -117,7 +117,7 @@ def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k): model_inputs = {'boxes': 
test_boxes, 'scores': test_scores} import mmdeploy.backend.onnxruntime as ort_apis - backend_model = ort_apis.ORTWrapper(onnx_model_path, 'cuda:0', None) + backend_model = ort_apis.ORTWrapper(onnx_model_path, 'cpu', None) output = backend_model.forward(model_inputs) output = backend_model.output_to_list(output) dets = output[0] @@ -205,7 +205,7 @@ def delta2bbox(*args, **kwargs): original_outputs = delta2bbox(rois, deltas, version='le90') # wrap function to nn.Module, enable torch.onnx.export - wrapped_func = WrapFunction(delta2bbox) + wrapped_func = WrapFunction(delta2bbox, version='le90') rewrite_outputs, is_backend_output = get_rewrite_outputs( wrapped_func, model_inputs={ @@ -270,3 +270,42 @@ def test_fake_multiclass_nms_rotated(): assert rewrite_outputs is not None, 'Got unexpected rewrite '\ 'outputs: {}'.format(rewrite_outputs) + + +@pytest.mark.parametrize('backend_type', [Backend.TENSORRT]) +def test_poly2obb_le90(backend_type: Backend): + check_backend(backend_type) + polys = torch.rand(1, 10, 8) + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(output_names=None, input_shape=None), + backend_config=dict( + type=backend_type.value, + model_inputs=[ + dict( + input_shapes=dict( + polys=dict( + min_shape=polys.shape, + opt_shape=polys.shape, + max_shape=polys.shape))) + ]), + codebase_config=dict(type='mmrotate', task='RotatedDetection'))) + + # import rewriter + from mmdeploy.codebase import Codebase, import_codebase + import_codebase(Codebase.MMROTATE) + + # wrap function to enable rewrite + def poly2obb_le90(*args, **kwargs): + import mmrotate + return mmrotate.core.bbox.transforms.poly2obb_le90(*args, **kwargs) + + # wrap function to nn.Module, enable torch.onnx.export + wrapped_func = WrapFunction(poly2obb_le90) + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_func, + model_inputs={'polys': polys}, + deploy_cfg=deploy_cfg, + run_with_backend=False) + + assert rewrite_outputs is not None diff --git a/tests/test_ops/test_ops.py 
b/tests/test_ops/test_ops.py index 273500f084..97ed573299 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -931,3 +931,77 @@ def wrapped_function(torch_input, torch_rois): input_names=['input', 'rois'], output_names=['roi_feat'], save_dir=save_dir) + + +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize( + 'out_size, clockwise, sampling_ratio, roi_scale_factor,' + ' finest_scale, featmap_strides, aligned', + [(tuple([2, 2]), False, 2, 1.0, 2, list([1.0]), 1)]) +def test_multi_level_rotated_roi_align(backend, + out_size, + clockwise, + sampling_ratio, + roi_scale_factor, + finest_scale, + featmap_strides, + aligned, + input_list=None, + save_dir=None): + backend.check_env() + + if input_list is None: + import numpy as np + input = [ + torch.tensor([[[[1., 2., 5., 6.], [3., 4., 7., 8.], + [9., 10., 13., 14.], [11., 12., 15., 16.]]]]) + ] + rois = torch.tensor([[0., 1.5, 1.5, 3., 3., np.pi / 2]]) + expected_result = torch.tensor([[[[7.5625, 1.9375], [10.375, 4.75]]]]) + else: + input = input_list[0] + rois = input_list[1] + expected_result = input_list[2] + input_name = [('input_' + str(i)) for i in range(len(featmap_strides))] + input_name.insert(0, 'rois') + + inputs = [ + onnx.helper.make_tensor_value_info( + input_name[i + 1], onnx.TensorProto.FLOAT, shape=input[i].shape) + for i in range(len(input_name) - 1) + ] + inputs.append( + onnx.helper.make_tensor_value_info( + 'rois', onnx.TensorProto.FLOAT, shape=rois.shape)) + outputs = [ + onnx.helper.make_tensor_value_info( + 'bbox_feats', onnx.TensorProto.FLOAT, shape=expected_result.shape) + ] + node = onnx.helper.make_node( + 'MMCVMultiLevelRotatedRoiAlign', + input_name, ['bbox_feats'], + 'MMCVMultiLevelRotatedRoiAlign_0', + None, + 'mmdeploy', + featmap_strides=featmap_strides, + finest_scale=finest_scale, + output_height=out_size[0], + output_width=out_size[1], + clockwise=clockwise, + roi_scale_factor=roi_scale_factor, + sampling_ratio=sampling_ratio, + 
aligned=aligned) + graph = onnx.helper.make_graph([node], 'torch-jit-export', inputs, outputs) + onnx_model = onnx.helper.make_model( + graph, producer_name='pytorch', producer_version='1.8') + onnx_model.opset_import[0].version = 11 + onnx_model.opset_import.append( + onnx.onnx_ml_pb2.OperatorSetIdProto(domain='mmdeploy', version=1)) + + backend.run_and_validate( + onnx_model, [rois, *input], + 'multi_level_rotated_roi_align', + input_names=input_name, + output_names=['bbox_feats'], + expected_result=expected_result, + save_dir=save_dir)