Skip to content

Commit

Permalink
Add configurable input size for TLT MaskRCNN Plugin (#986)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Zhu <[email protected]>

Co-authored-by: Tyler Zhu <[email protected]>
  • Loading branch information
Tyler-D and Tyler-D authored Dec 29, 2020
1 parent 1565fe7 commit 0914fec
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 48 deletions.
40 changes: 26 additions & 14 deletions plugin/generateDetectionPlugin/generateDetectionPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "generateDetectionPlugin.h"
#include "plugin.h"
#include <cuda_runtime_api.h>
#include <algorithm>

using namespace nvinfer1;
using namespace plugin;
Expand All @@ -40,6 +41,7 @@ GenerateDetectionPluginCreator::GenerateDetectionPluginCreator()
mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
Expand All @@ -62,6 +64,7 @@ const PluginFieldCollection* GenerateDetectionPluginCreator::getFieldNames()

IPluginV2Ext* GenerateDetectionPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE;
const PluginField* fields = fc->fields;
for (int i = 0; i < fc->nbFields; ++i)
{
Expand All @@ -86,20 +89,27 @@ IPluginV2Ext* GenerateDetectionPluginCreator::createPlugin(const char* name, con
assert(fields[i].type == PluginFieldType::kFLOAT32);
mIOUThreshold = *(static_cast<const float*>(fields[i].data));
}
if (!strcmp(attrName, "image_size"))
{
assert(fields[i].type == PluginFieldType::kINT32);
const auto dims = static_cast<const int32_t*>(fields[i].data);
std::copy_n(dims, 3, image_size.d);
}
}
return new GenerateDetection(mNbClasses, mKeepTopK, mScoreThreshold, mIOUThreshold);
return new GenerateDetection(mNbClasses, mKeepTopK, mScoreThreshold, mIOUThreshold, image_size);
};

IPluginV2Ext* GenerateDetectionPluginCreator::deserializePlugin(const char* name, const void* data, size_t length)
{
return new GenerateDetection(data, length);
};

GenerateDetection::GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold)
GenerateDetection::GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold, const nvinfer1::Dims& image_size)
: mNbClasses(num_classes)
, mKeepTopK(keep_topk)
, mScoreThreshold(score_threshold)
, mIOUThreshold(iou_threshold)
, mImageSize(image_size)
{
mBackgroundLabel = 0;
assert(mNbClasses > 0);
Expand Down Expand Up @@ -178,7 +188,7 @@ const char* GenerateDetection::getPluginNamespace() const

size_t GenerateDetection::getSerializationSize() const
{
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * 2;
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * 2 + sizeof(nvinfer1::Dims);
};

void GenerateDetection::serialize(void* buffer) const
Expand All @@ -190,6 +200,7 @@ void GenerateDetection::serialize(void* buffer) const
write(d, mIOUThreshold);
write(d, mMaxBatchSize);
write(d, mAnchorsCnt);
write(d, mImageSize);
ASSERT(d == a + getSerializationSize());
};

Expand All @@ -202,6 +213,7 @@ GenerateDetection::GenerateDetection(const void* data, size_t length)
float iou_threshold = read<float>(d);
mMaxBatchSize = read<int>(d);
mAnchorsCnt = read<int>(d);
mImageSize = read<nvinfer1::Dims3>(d);
ASSERT(d == a + length);

mNbClasses = num_classes;
Expand Down Expand Up @@ -264,17 +276,17 @@ int GenerateDetection::enqueue(

// refine detection
RefineDetectionWorkSpace refDetcWorkspace(batch_size, mAnchorsCnt, mParam, mType);
cudaError_t status = DetectionPostProcess(stream, batch_size, mAnchorsCnt,
static_cast<float*>(mRegWeightDevice->mPtr),
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]), // Image Height
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]), // Image Width
DataType::kFLOAT, // mType,
mParam, refDetcWorkspace, workspace,
inputs[1], // inputs[InScore]
inputs[0], // inputs[InDelta],
mValidCnt->mPtr, // inputs[InCountValid],
inputs[2], // inputs[ROI]
detections);
cudaError_t status
= DetectionPostProcess(stream, batch_size, mAnchorsCnt, static_cast<float*>(mRegWeightDevice->mPtr),
static_cast<float>(mImageSize.d[1]), // Image Height
static_cast<float>(mImageSize.d[2]), // Image Width
DataType::kFLOAT, // mType,
mParam, refDetcWorkspace, workspace,
inputs[1], // inputs[InScore]
inputs[0], // inputs[InDelta],
mValidCnt->mPtr, // inputs[InCountValid],
inputs[2], // inputs[ROI]
detections);

assert(status == cudaSuccess);
return status;
Expand Down
4 changes: 3 additions & 1 deletion plugin/generateDetectionPlugin/generateDetectionPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace plugin
class GenerateDetection : public IPluginV2Ext
{
public:
GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold);
GenerateDetection(int num_classes, int keep_topk, float score_threshold, float iou_threshold, const nvinfer1::Dims& image_size);

GenerateDetection(const void* data, size_t length);

Expand Down Expand Up @@ -103,6 +103,8 @@ class GenerateDetection : public IPluginV2Ext
RefineNMSParameters mParam;
std::shared_ptr<CudaBind<float>> mRegWeightDevice;

nvinfer1::Dims mImageSize;

std::string mNameSpace;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "multilevelCropAndResizePlugin.h"
#include "plugin.h"
#include <cuda_runtime_api.h>
#include <algorithm>

#include <fstream>

Expand All @@ -36,6 +37,7 @@ std::vector<PluginField> MultilevelCropAndResizePluginCreator::mPluginAttributes
MultilevelCropAndResizePluginCreator::MultilevelCropAndResizePluginCreator()
{
mPluginAttributes.emplace_back(PluginField("pooled_size", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
Expand All @@ -58,6 +60,7 @@ const PluginFieldCollection* MultilevelCropAndResizePluginCreator::getFieldNames

IPluginV2Ext* MultilevelCropAndResizePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE;
const PluginField* fields = fc->fields;
for (int i = 0; i < fc->nbFields; ++i)
{
Expand All @@ -67,25 +70,31 @@ IPluginV2Ext* MultilevelCropAndResizePluginCreator::createPlugin(const char* nam
assert(fields[i].type == PluginFieldType::kINT32);
mPooledSize = *(static_cast<const int*>(fields[i].data));
}
if (!strcmp(attrName, "image_size"))
{
assert(fields[i].type == PluginFieldType::kINT32);
const auto dims = static_cast<const int32_t*>(fields[i].data);
std::copy_n(dims, 3, image_size.d);
}
}
return new MultilevelCropAndResize(mPooledSize);
return new MultilevelCropAndResize(mPooledSize, image_size);
};

IPluginV2Ext* MultilevelCropAndResizePluginCreator::deserializePlugin(const char* name, const void* data, size_t length)
{
return new MultilevelCropAndResize(data, length);
};

MultilevelCropAndResize::MultilevelCropAndResize(int pooled_size)
MultilevelCropAndResize::MultilevelCropAndResize(int pooled_size, const nvinfer1::Dims& image_size)
: mPooledSize({pooled_size, pooled_size})
{

assert(pooled_size > 0);
// shape
mInputHeight = TLTMaskRCNNConfig::IMAGE_SHAPE.d[1];
mInputWidth = TLTMaskRCNNConfig::IMAGE_SHAPE.d[2];
//Threshold to P3: Smaller -> P2
mThresh = (224*224) / (4.0f);
mInputHeight = image_size.d[1];
mInputWidth = image_size.d[2];
// Threshold to P3: Smaller -> P2
mThresh = (224 * 224) / (4.0f);
};

int MultilevelCropAndResize::getNbOutputs() const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace plugin
class MultilevelCropAndResize : public IPluginV2Ext
{
public:
MultilevelCropAndResize(int pooled_size);
MultilevelCropAndResize(int pooled_size, const nvinfer1::Dims& image_size);

MultilevelCropAndResize(const void* data, size_t length);

Expand Down
55 changes: 31 additions & 24 deletions plugin/multilevelProposeROI/multilevelProposeROIPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "plugin.h"
#include <cuda_runtime_api.h>
#include <iostream>
#include <algorithm>
#include <math.h>

#include <fstream>
Expand All @@ -43,6 +44,7 @@ MultilevelProposeROIPluginCreator::MultilevelProposeROIPluginCreator()
mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("fg_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
mPluginAttributes.emplace_back(PluginField("image_size", nullptr, PluginFieldType::kINT32, 3));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
Expand All @@ -65,6 +67,7 @@ const PluginFieldCollection* MultilevelProposeROIPluginCreator::getFieldNames()

IPluginV2Ext* MultilevelProposeROIPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc)
{
auto image_size = TLTMaskRCNNConfig::IMAGE_SHAPE;
const PluginField* fields = fc->fields;
for (int i = 0; i < fc->nbFields; ++i)
{
Expand All @@ -89,20 +92,27 @@ IPluginV2Ext* MultilevelProposeROIPluginCreator::createPlugin(const char* name,
assert(fields[i].type == PluginFieldType::kFLOAT32);
mIOUThreshold = *(static_cast<const float*>(fields[i].data));
}
if (!strcmp(attrName, "image_size"))
{
assert(fields[i].type == PluginFieldType::kINT32);
const auto dims = static_cast<const int32_t*>(fields[i].data);
std::copy_n(dims, 3, image_size.d);
}
}
return new MultilevelProposeROI(mPreNMSTopK, mKeepTopK, mFGThreshold, mIOUThreshold);
return new MultilevelProposeROI(mPreNMSTopK, mKeepTopK, mFGThreshold, mIOUThreshold, image_size);
};

IPluginV2Ext* MultilevelProposeROIPluginCreator::deserializePlugin(const char* name, const void* data, size_t length)
{
return new MultilevelProposeROI(data, length);
};

MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold)
MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold, const nvinfer1::Dims image_size)
: mPreNMSTopK(prenms_topk)
, mKeepTopK(keep_topk)
, mFGThreshold(fg_threshold)
, mIOUThreshold(iou_threshold)
, mImageSize(image_size)
{
mBackgroundLabel = -1;
assert(mPreNMSTopK > 0);
Expand All @@ -121,7 +131,7 @@ MultilevelProposeROI::MultilevelProposeROI(int prenms_topk, int keep_topk, float

mFeatureCnt = TLTMaskRCNNConfig::MAX_LEVEL - TLTMaskRCNNConfig::MIN_LEVEL + 1;

generate_pyramid_anchors();
generate_pyramid_anchors(mImageSize);
};

int MultilevelProposeROI::getNbOutputs() const
Expand Down Expand Up @@ -224,7 +234,7 @@ const char* MultilevelProposeROI::getPluginNamespace() const

size_t MultilevelProposeROI::getSerializationSize() const
{
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * (mFeatureCnt + 1);
return sizeof(int) * 2 + sizeof(float) * 2 + sizeof(int) * (mFeatureCnt + 1) + sizeof(nvinfer1::Dims);
};

void MultilevelProposeROI::serialize(void* buffer) const
Expand All @@ -239,6 +249,7 @@ void MultilevelProposeROI::serialize(void* buffer) const
{
write(d, mAnchorsCnt[i]);
}
write(d, mImageSize);
ASSERT(d == a + getSerializationSize());
};

Expand All @@ -257,6 +268,7 @@ MultilevelProposeROI::MultilevelProposeROI(const void* data, size_t length)
{
mAnchorsCnt.push_back(read<int>(d));
}
mImageSize = read<nvinfer1::Dims3>(d);
ASSERT(d == a + length);

mBackgroundLabel = -1;
Expand All @@ -273,7 +285,7 @@ MultilevelProposeROI::MultilevelProposeROI(const void* data, size_t length)

mType = DataType::kFLOAT;

generate_pyramid_anchors();
generate_pyramid_anchors(mImageSize);
};

void MultilevelProposeROI::check_valid_inputs(const nvinfer1::Dims* inputs, int nbInputDims)
Expand Down Expand Up @@ -329,9 +341,9 @@ Dims MultilevelProposeROI::getOutputDimensions(int index, const Dims* inputs, in
return proposals;
}

void MultilevelProposeROI::generate_pyramid_anchors()
void MultilevelProposeROI::generate_pyramid_anchors(const nvinfer1::Dims& image_size)
{
const auto image_dims = TLTMaskRCNNConfig::IMAGE_SHAPE;
const auto image_dims = image_size;

const auto& anchor_scale = TLTMaskRCNNConfig::RPN_ANCHOR_SCALE;
const auto& min_level = TLTMaskRCNNConfig::MIN_LEVEL;
Expand Down Expand Up @@ -388,23 +400,18 @@ int MultilevelProposeROI::enqueue(
{

MultilevelProposeROIWorkSpace proposal_ws(batch_size, mAnchorsCnt[i], mPreNMSTopK, mParam, mType);
status = MultilevelPropose(stream,
batch_size,
mAnchorsCnt[i],
mPreNMSTopK,
static_cast<float*>(mRegWeightDevice->mPtr),
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[1]), //Input Height
static_cast<float>(TLTMaskRCNNConfig::IMAGE_SHAPE.d[2]),
DataType::kFLOAT, // mType,
mParam,
proposal_ws,
workspace + kernel_workspace_offset,
inputs[2*i + 1], // inputs[object_score],
inputs[2*i], // inputs[bbox_delta]
mValidCnt->mPtr,
mAnchorBoxesDevice[i]->mPtr, // inputs[anchors]
mTempScores[i]->mPtr, //temp scores [batch_size, topk, 1]
mTempBboxes[i]->mPtr); //temp
status = MultilevelPropose(stream, batch_size, mAnchorsCnt[i], mPreNMSTopK,
static_cast<float*>(mRegWeightDevice->mPtr),
static_cast<float>(mImageSize.d[1]), // Input Height
static_cast<float>(mImageSize.d[2]),
DataType::kFLOAT, // mType,
mParam, proposal_ws, workspace + kernel_workspace_offset,
inputs[2 * i + 1], // inputs[object_score],
inputs[2 * i], // inputs[bbox_delta]
mValidCnt->mPtr,
mAnchorBoxesDevice[i]->mPtr, // inputs[anchors]
mTempScores[i]->mPtr, // temp scores [batch_size, topk, 1]
mTempBboxes[i]->mPtr); // temp
assert(status == cudaSuccess);
kernel_workspace_offset += proposal_ws.totalSize;
}
Expand Down
5 changes: 3 additions & 2 deletions plugin/multilevelProposeROI/multilevelProposeROIPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace plugin
class MultilevelProposeROI : public IPluginV2Ext
{
public:
MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold);
MultilevelProposeROI(int prenms_topk, int keep_topk, float fg_threshold, float iou_threshold, const nvinfer1::Dims image_size);

MultilevelProposeROI(const void* data, size_t length);

Expand Down Expand Up @@ -88,7 +88,7 @@ class MultilevelProposeROI : public IPluginV2Ext

private:
void check_valid_inputs(const nvinfer1::Dims* inputs, int nbInputDims);
void generate_pyramid_anchors();
void generate_pyramid_anchors(const nvinfer1::Dims& image_size);

int mBackgroundLabel;
int mPreNMSTopK;
Expand All @@ -111,6 +111,7 @@ class MultilevelProposeROI : public IPluginV2Ext
float** mDeviceBboxes;
std::shared_ptr<CudaBind<float>> mRegWeightDevice;

nvinfer1::Dims mImageSize;
nvinfer1::DataType mType;
RefineNMSParameters mParam;

Expand Down

0 comments on commit 0914fec

Please sign in to comment.