From c1203b037e2d20f6eb758120ac18babfb170fab5 Mon Sep 17 00:00:00 2001 From: Tyler Zhu Date: Wed, 18 Nov 2020 16:16:01 +0800 Subject: [PATCH] Add configurable input size for TLT MaskRCNN Plugin Signed-off-by: Tyler Zhu --- .../generateDetectionPlugin.cpp | 40 +++++++++----- .../generateDetectionPlugin.h | 4 +- .../multilevelCropAndResizePlugin.cpp | 21 +++++-- .../multilevelCropAndResizePlugin.h | 2 +- .../multilevelProposeROIPlugin.cpp | 55 +++++++++++-------- .../multilevelProposeROIPlugin.h | 5 +- 6 files changed, 79 insertions(+), 48 deletions(-) diff --git a/plugin/generateDetectionPlugin/generateDetectionPlugin.cpp b/plugin/generateDetectionPlugin/generateDetectionPlugin.cpp index 44d45ccc..e4abc8de 100644 --- a/plugin/generateDetectionPlugin/generateDetectionPlugin.cpp +++ b/plugin/generateDetectionPlugin/generateDetectionPlugin.cpp @@ -16,6 +16,7 @@ #include "generateDetectionPlugin.h" #include "plugin.h" #include +#include using namespace nvinfer1; using namespace plugin; @@ -40,6 +41,7 @@ GenerateDetectionPluginCreator::GenerateDetectionPluginCreator() mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); @@ -62,6 +64,7 @@ const PluginFieldCollection* GenerateDetectionPluginCreator::getFieldNames() IPluginV2Ext* GenerateDetectionPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) { + auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE; const PluginField* fields = fc->fields; for (int i = 0; i < fc->nbFields; ++i) { @@ -86,8 +89,14 @@ IPluginV2Ext* GenerateDetectionPluginCreator::createPlugin(const char* name, con assert(fields[i].type == PluginFieldType::kFLOAT32); mIOUThreshold = *(static_cast(fields[i].data)); } + if (!strcmp(attrName, "image_size")) + { + assert(fields[i].type == PluginFieldType::kINT32); + const auto dims = static_cast(fields[i].data); + std::copy_n(dims, 3, image_size.d); + } } - return new GenerateDetection(mNbClasses, mKeepTopK, mScoreThreshold, mIOUThreshold); + return new GenerateDetection(mNbClasses, mKeepTopK, mScoreThreshold, mIOUThreshold, image_size); }; IPluginV2Ext* GenerateDetectionPluginCreator::deserializePlugin(const char* name, const void* data, size_t length) @@ -95,11 +104,12 @@ IPluginV2Ext* GenerateDetectionPluginCreator::deserializePlugin(const char* name return new GenerateDetection(data, length); }; -GenerateDetection::GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold) +GenerateDetection::GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold, const nvinfer1::Dims& image_size) : mNbClasses(num_classes) , mKeepTopK(keep_topk) , mScoreThreshold(score_threshold) , mIOUThreshold(iou_threshold) + , mImageSize(image_size) { mBackgroundLabel = 0; assert(mNbClasses > 0); @@ -178,7 +188,7 @@ const char* GenerateDetection::getPluginNamespace() const size_t GenerateDetection::getSerializationSize() const { - return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * 2; + return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * 2 + sizeof(nvinfer1::Dims); }; void GenerateDetection::serialize(void* buffer) const @@ -190,6 +200,7 @@ void GenerateDetection::serialize(void* buffer) const write(d, mIOUThreshold); write(d, mMaxBatchSize); write(d, mAnchorsCnt); + write(d, mImageSize); ASSERT(d == a + getSerializationSize()); }; @@ -202,6 +213,7 @@ GenerateDetection::GenerateDetection(const void* data, size_t length) float iou_threshold = read(d); mMaxBatchSize = read(d); mAnchorsCnt = read(d); + mImageSize = read(d); ASSERT(d == a + length); mNbClasses = num_classes; @@ -264,17 +276,17 @@ int GenerateDetection::enqueue( // refine detection RefineDetectionWorkSpace refDetcWorkspace(batch_size, mAnchorsCnt, mParam, mType); - cudaError_t status = DetectionPostProcess(stream, batch_size, mAnchorsCnt, - static_cast(mRegWeightDevice->mPtr), - static_cast(TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]), // Image Height - static_cast(TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]), // Image Width - DataType::kFLOAT, // mType, - mParam, refDetcWorkspace, workspace, - inputs[1], // inputs[InScore] - inputs[0], // inputs[InDelta], - mValidCnt->mPtr, // inputs[InCountValid], - inputs[2], // inputs[ROI] - detections); + cudaError_t status + = DetectionPostProcess(stream, batch_size, mAnchorsCnt, static_cast(mRegWeightDevice->mPtr), + static_cast(mImageSize.d[1]), // Image Height + static_cast(mImageSize.d[2]), // Image Width + DataType::kFLOAT, // mType, + mParam, refDetcWorkspace, workspace, + inputs[1], // inputs[InScore] + inputs[0], // inputs[InDelta], + mValidCnt->mPtr, // inputs[InCountValid], + inputs[2], // inputs[ROI] + detections); assert(status == cudaSuccess); return status; diff --git a/plugin/generateDetectionPlugin/generateDetectionPlugin.h b/plugin/generateDetectionPlugin/generateDetectionPlugin.h index eeab618a..22618bb0 100644 --- a/plugin/generateDetectionPlugin/generateDetectionPlugin.h +++ b/plugin/generateDetectionPlugin/generateDetectionPlugin.h @@ -35,7 +35,7 @@ namespace plugin class GenerateDetection : public IPluginV2Ext { public: - GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold); + GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold, const nvinfer1::Dims& image_size); GenerateDetection(const void* data, size_t length); @@ -103,6 +103,8 @@ class GenerateDetection : public IPluginV2Ext RefineNMSParameters mParam; std::shared_ptr> mRegWeightDevice; + nvinfer1::Dims mImageSize; + std::string mNameSpace; }; diff --git a/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.cpp b/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.cpp index c93a17fe..30ffee31 100644 --- a/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.cpp +++ b/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.cpp @@ -16,6 +16,7 @@ #include "multilevelCropAndResizePlugin.h" #include "plugin.h" #include +#include #include @@ -36,6 +37,7 @@ std::vector MultilevelCropAndResizePluginCreator::mPluginAttributes MultilevelCropAndResizePluginCreator::MultilevelCropAndResizePluginCreator() { mPluginAttributes.emplace_back(PluginField("pooled_size", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); @@ -58,6 +60,7 @@ const PluginFieldCollection* MultilevelCropAndResizePluginCreator::getFieldNames IPluginV2Ext* MultilevelCropAndResizePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) { + auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE; const PluginField* fields = fc->fields; for (int i = 0; i < fc->nbFields; ++i) { @@ -67,8 +70,14 @@ IPluginV2Ext* MultilevelCropAndResizePluginCreator::createPlugin(const char* nam assert(fields[i].type == PluginFieldType::kINT32); mPooledSize = *(static_cast(fields[i].data)); } + if (!strcmp(attrName, "image_size")) + { + assert(fields[i].type == PluginFieldType::kINT32); + const auto dims = static_cast(fields[i].data); + std::copy_n(dims, 3, image_size.d); + } } - return new MultilevelCropAndResize(mPooledSize); + return new MultilevelCropAndResize(mPooledSize, image_size); }; IPluginV2Ext* MultilevelCropAndResizePluginCreator::deserializePlugin(const char* name, const void* data, size_t length) @@ -76,16 +85,16 @@ IPluginV2Ext* MultilevelCropAndResizePluginCreator::deserializePlugin(const char return new MultilevelCropAndResize(data, length); }; -MultilevelCropAndResize::MultilevelCropAndResize(int pooled_size) +MultilevelCropAndResize::MultilevelCropAndResize(int pooled_size, const nvinfer1::Dims& image_size) : mPooledSize({pooled_size, pooled_size}) { assert(pooled_size > 0); // shape - mInputHeight = TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]; - mInputWidth = TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]; - //Threshold to P3: Smaller -> P2 - mThresh = (224*224) / (4.0f); + mInputHeight = image_size.d[1]; + mInputWidth = image_size.d[2]; + // Threshold to P3: Smaller -> P2 + mThresh = (224 * 224) / (4.0f); }; int MultilevelCropAndResize::getNbOutputs() const diff --git a/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.h b/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.h index 89ceedca..c36de479 100644 --- a/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.h +++ b/plugin/multilevelCropAndResizePlugin/multilevelCropAndResizePlugin.h @@ -35,7 +35,7 @@ namespace plugin class MultilevelCropAndResize : public IPluginV2Ext { public: - MultilevelCropAndResize(int pooled_size); + MultilevelCropAndResize(int pooled_size, const nvinfer1::Dims& image_size); MultilevelCropAndResize(const void* data, size_t length); diff --git a/plugin/multilevelProposeROI/multilevelProposeROIPlugin.cpp b/plugin/multilevelProposeROI/multilevelProposeROIPlugin.cpp index 2f06fab2..66b18e56 100644 --- a/plugin/multilevelProposeROI/multilevelProposeROIPlugin.cpp +++ b/plugin/multilevelProposeROI/multilevelProposeROIPlugin.cpp @@ -18,6 +18,7 @@ #include "plugin.h" #include #include +#include #include #include @@ -43,6 +44,7 @@ MultilevelProposeROIPluginCreator::MultilevelProposeROIPluginCreator() mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("fg_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); @@ -65,6 +67,7 @@ const PluginFieldCollection* MultilevelProposeROIPluginCreator::getFieldNames() IPluginV2Ext* MultilevelProposeROIPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) { + auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE; const PluginField* fields = fc->fields; for (int i = 0; i < fc->nbFields; ++i) { @@ -89,8 +92,14 @@ IPluginV2Ext* MultilevelProposeROIPluginCreator::createPlugin(const char* name, assert(fields[i].type == PluginFieldType::kFLOAT32); mIOUThreshold = *(static_cast(fields[i].data)); } + if (!strcmp(attrName, "image_size")) + { + assert(fields[i].type == PluginFieldType::kINT32); + const auto dims = static_cast(fields[i].data); + std::copy_n(dims, 3, image_size.d); + } } - return new MultilevelProposeROI(mPreNMSTopK, mKeepTopK, mFGThreshold, mIOUThreshold); + return new MultilevelProposeROI(mPreNMSTopK, mKeepTopK, mFGThreshold, mIOUThreshold, image_size); }; IPluginV2Ext* MultilevelProposeROIPluginCreator::deserializePlugin(const char* name, const void* data, size_t length) @@ -98,11 +107,12 @@ IPluginV2Ext* MultilevelProposeROIPluginCreator::deserializePlugin(const char* n return new MultilevelProposeROI(data, length); }; -MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold) +MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold, const nvinfer1::Dims image_size) : mPreNMSTopK(prenms_topk) , mKeepTopK(keep_topk) , mFGThreshold(fg_threshold) , mIOUThreshold(iou_threshold) + , mImageSize(image_size) { mBackgroundLabel = -1; assert(mPreNMSTopK > 0); @@ -121,7 +131,7 @@ MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float mFeatureCnt = TLTMaskRCNNConfig::MAX_LEVEL - TLTMaskRCNNConfig::MIN_LEVEL + 1; - generate_pyramid_anchors(); + generate_pyramid_anchors(mImageSize); }; int MultilevelProposeROI::getNbOutputs() const @@ -224,7 +234,7 @@ const char* MultilevelProposeROI::getPluginNamespace() const size_t MultilevelProposeROI::getSerializationSize() const { - return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * (mFeatureCnt + 1); + return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * (mFeatureCnt + 1) + sizeof(nvinfer1::Dims); }; void MultilevelProposeROI::serialize(void* buffer) const @@ -239,6 +249,7 @@ void MultilevelProposeROI::serialize(void* buffer) const { write(d, mAnchorsCnt[i]); } + write(d, mImageSize); ASSERT(d == a + getSerializationSize()); }; @@ -257,6 +268,7 @@ MultilevelProposeROI::MultilevelProposeROI(const void* data, size_t length) { mAnchorsCnt.push_back(read(d)); } + mImageSize = read(d); ASSERT(d == a + length); mBackgroundLabel = -1; @@ -273,7 +285,7 @@ MultilevelProposeROI::MultilevelProposeROI(const void* data, size_t length) mType = DataType::kFLOAT; - generate_pyramid_anchors(); + generate_pyramid_anchors(mImageSize); }; void MultilevelProposeROI::check_valid_inputs(const nvinfer1::Dims* inputs, int nbInputDims) @@ -329,9 +341,9 @@ Dims MultilevelProposeROI::getOutputDimensions(int index, const Dims* inputs, in return proposals; } -void MultilevelProposeROI::generate_pyramid_anchors() +void MultilevelProposeROI::generate_pyramid_anchors(const nvinfer1::Dims& image_size) { - const auto image_dims = TLTMaskRCNNConfig::IMAGE_SHAPE; + const auto image_dims = image_size; const auto& anchor_scale = TLTMaskRCNNConfig::RPN_ANCHOR_SCALE; const auto& min_level = TLTMaskRCNNConfig::MIN_LEVEL; @@ -388,23 +400,18 @@ int MultilevelProposeROI::enqueue( { MultilevelProposeROIWorkSpace proposal_ws(batch_size, mAnchorsCnt[i], mPreNMSTopK, mParam, mType); - status = MultilevelPropose(stream, - batch_size, - mAnchorsCnt[i], - mPreNMSTopK, - static_cast(mRegWeightDevice->mPtr), - static_cast(TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]), //Input Height - static_cast(TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]), - DataType::kFLOAT, // mType, - mParam, - proposal_ws, - workspace + kernel_workspace_offset, - inputs[2*i + 1], // inputs[object_score], - inputs[2*i], // inputs[bbox_delta] - mValidCnt->mPtr, - mAnchorBoxesDevice[i]->mPtr, // inputs[anchors] - mTempScores[i]->mPtr, //temp scores [batch_size, topk, 1] - mTempBboxes[i]->mPtr); //temp + status = MultilevelPropose(stream, batch_size, mAnchorsCnt[i], mPreNMSTopK, + static_cast(mRegWeightDevice->mPtr), + static_cast(mImageSize.d[1]), // Input Height + static_cast(mImageSize.d[2]), + DataType::kFLOAT, // mType, + mParam, proposal_ws, workspace + kernel_workspace_offset, + inputs[2 * i + 1], // inputs[object_score], + inputs[2 * i], // inputs[bbox_delta] + mValidCnt->mPtr, + mAnchorBoxesDevice[i]->mPtr, // inputs[anchors] + mTempScores[i]->mPtr, // temp scores [batch_size, topk, 1] + mTempBboxes[i]->mPtr); // temp assert(status == cudaSuccess); kernel_workspace_offset += proposal_ws.totalSize; } diff --git a/plugin/multilevelProposeROI/multilevelProposeROIPlugin.h b/plugin/multilevelProposeROI/multilevelProposeROIPlugin.h index 98306435..6dc62f6f 100644 --- a/plugin/multilevelProposeROI/multilevelProposeROIPlugin.h +++ b/plugin/multilevelProposeROI/multilevelProposeROIPlugin.h @@ -34,7 +34,7 @@ namespace plugin class MultilevelProposeROI : public IPluginV2Ext { public: - MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold); + MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold, const nvinfer1::Dims image_size); MultilevelProposeROI(const void* data, size_t length); @@ -88,7 +88,7 @@ class MultilevelProposeROI : public IPluginV2Ext private: void check_valid_inputs(const nvinfer1::Dims* inputs, int nbInputDims); - void generate_pyramid_anchors(); + void generate_pyramid_anchors(const nvinfer1::Dims& image_size); int mBackgroundLabel; int mPreNMSTopK; @@ -111,6 +111,7 @@ class MultilevelProposeROI : public IPluginV2Ext float** mDeviceBboxes; std::shared_ptr> mRegWeightDevice; + nvinfer1::Dims mImageSize; nvinfer1::DataType mType; RefineNMSParameters mParam;