Skip to content

Commit

Permalink
[Feature] better tensorrt cpp code (open-mmlab#11)
Browse files Browse the repository at this point in the history
* better tensorrt cpp code

* fix end of file
  • Loading branch information
grimoire authored Jul 12, 2021
1 parent 342e195 commit 3fa94f4
Show file tree
Hide file tree
Showing 27 changed files with 573 additions and 1,043 deletions.
99 changes: 36 additions & 63 deletions backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#include <cstring>

#include "kernel.h"
#include "trt_batched_nms_kernel.hpp"
#include "trt_serialize.hpp"

namespace mmlab {
using namespace nvinfer1;
using nvinfer1::plugin::NMSParameters;

Expand All @@ -16,25 +18,22 @@ static const char* NMS_PLUGIN_VERSION{"1"};
static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"};
} // namespace

TRTBatchedNMSPluginDynamic::TRTBatchedNMSPluginDynamic(NMSParameters params)
: param(params) {}
TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params)
: TRTPluginBase(name), param(params) {}

TRTBatchedNMSPluginDynamic::TRTBatchedNMSPluginDynamic(const void* data,
size_t length) {
TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data,
size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &param);
deserialize_value(&data, &length, &boxesSize);
deserialize_value(&data, &length, &scoresSize);
deserialize_value(&data, &length, &numPriors);
deserialize_value(&data, &length, &mClipBoxes);
}

int TRTBatchedNMSPluginDynamic::getNbOutputs() const { return 2; }
int TRTBatchedNMS::getNbOutputs() const { return 2; }

int TRTBatchedNMSPluginDynamic::initialize() { return STATUS_SUCCESS; }

void TRTBatchedNMSPluginDynamic::terminate() {}

nvinfer1::DimsExprs TRTBatchedNMSPluginDynamic::getOutputDimensions(
nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) {
ASSERT(nbInputs == 2);
Expand All @@ -60,7 +59,7 @@ nvinfer1::DimsExprs TRTBatchedNMSPluginDynamic::getOutputDimensions(
return ret;
}

size_t TRTBatchedNMSPluginDynamic::getWorkspaceSize(
size_t TRTBatchedNMS::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
size_t batch_size = inputs[0].dims.d[0];
Expand All @@ -75,10 +74,10 @@ size_t TRTBatchedNMSPluginDynamic::getWorkspaceSize(
num_priors, topk, DataType::kFLOAT, DataType::kFLOAT);
}

int TRTBatchedNMSPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workSpace, cudaStream_t stream) {
int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs,
void* workSpace, cudaStream_t stream) {
const void* const locData = inputs[0];
const void* const confData = inputs[1];

Expand All @@ -103,26 +102,26 @@ int TRTBatchedNMSPluginDynamic::enqueue(
return 0;
}

size_t TRTBatchedNMSPluginDynamic::getSerializationSize() const {
size_t TRTBatchedNMS::getSerializationSize() const {
// NMSParameters, boxesSize,scoresSize,numPriors
return sizeof(NMSParameters) + sizeof(int) * 3 + sizeof(bool);
}

void TRTBatchedNMSPluginDynamic::serialize(void* buffer) const {
void TRTBatchedNMS::serialize(void* buffer) const {
serialize_value(&buffer, param);
serialize_value(&buffer, boxesSize);
serialize_value(&buffer, scoresSize);
serialize_value(&buffer, numPriors);
serialize_value(&buffer, mClipBoxes);
}

void TRTBatchedNMSPluginDynamic::configurePlugin(
void TRTBatchedNMS::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) {
// Validate input arguments
}

bool TRTBatchedNMSPluginDynamic::supportsFormatCombination(
bool TRTBatchedNMS::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) {
if (pos == 3) {
Expand All @@ -133,18 +132,14 @@ bool TRTBatchedNMSPluginDynamic::supportsFormatCombination(
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}

const char* TRTBatchedNMSPluginDynamic::getPluginType() const {
return NMS_PLUGIN_NAME;
}
const char* TRTBatchedNMS::getPluginType() const { return NMS_PLUGIN_NAME; }

const char* TRTBatchedNMSPluginDynamic::getPluginVersion() const {
const char* TRTBatchedNMS::getPluginVersion() const {
return NMS_PLUGIN_VERSION;
}

void TRTBatchedNMSPluginDynamic::destroy() { delete this; }

IPluginV2DynamicExt* TRTBatchedNMSPluginDynamic::clone() const {
auto* plugin = new TRTBatchedNMSPluginDynamic(param);
IPluginV2DynamicExt* TRTBatchedNMS::clone() const {
auto* plugin = new TRTBatchedNMS(mLayerName, param);
plugin->boxesSize = boxesSize;
plugin->scoresSize = scoresSize;
plugin->numPriors = numPriors;
Expand All @@ -153,16 +148,7 @@ IPluginV2DynamicExt* TRTBatchedNMSPluginDynamic::clone() const {
return plugin;
}

void TRTBatchedNMSPluginDynamic::setPluginNamespace(
const char* pluginNamespace) {
mNamespace = pluginNamespace;
}

const char* TRTBatchedNMSPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}

nvinfer1::DataType TRTBatchedNMSPluginDynamic::getOutputDataType(
nvinfer1::DataType TRTBatchedNMS::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
ASSERT(index >= 0 && index < this->getNbOutputs());
if (index == 1) {
Expand All @@ -171,10 +157,9 @@ nvinfer1::DataType TRTBatchedNMSPluginDynamic::getOutputDataType(
return inputTypes[0];
}

void TRTBatchedNMSPluginDynamic::setClipParam(bool clip) { mClipBoxes = clip; }
void TRTBatchedNMS::setClipParam(bool clip) { mClipBoxes = clip; }

TRTBatchedNMSPluginDynamicCreator::TRTBatchedNMSPluginDynamicCreator()
: params{} {
TRTBatchedNMSCreator::TRTBatchedNMSCreator() {
mPluginAttributes.emplace_back(
PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(
Expand All @@ -196,23 +181,19 @@ TRTBatchedNMSPluginDynamicCreator::TRTBatchedNMSPluginDynamicCreator()
mFC.fields = mPluginAttributes.data();
}

const char* TRTBatchedNMSPluginDynamicCreator::getPluginName() const {
const char* TRTBatchedNMSCreator::getPluginName() const {
return NMS_PLUGIN_NAME;
}

const char* TRTBatchedNMSPluginDynamicCreator::getPluginVersion() const {
const char* TRTBatchedNMSCreator::getPluginVersion() const {
return NMS_PLUGIN_VERSION;
}

const PluginFieldCollection*
TRTBatchedNMSPluginDynamicCreator::getFieldNames() {
return &mFC;
}

IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::createPlugin(
IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(
const char* name, const PluginFieldCollection* fc) {
const PluginField* fields = fc->fields;
bool clipBoxes = true;
nvinfer1::plugin::NMSParameters params{};

for (int i = 0; i < fc->nbFields; ++i) {
const char* attrName = fields[i].name;
Expand Down Expand Up @@ -241,29 +222,21 @@ IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::createPlugin(
}
}

TRTBatchedNMSPluginDynamic* plugin = new TRTBatchedNMSPluginDynamic(params);
TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params);
plugin->setClipParam(clipBoxes);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}

IPluginV2Ext* TRTBatchedNMSPluginDynamicCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) {
IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) {
// This object will be deleted when the network is destroyed, which will
// call NMS::destroy()
TRTBatchedNMSPluginDynamic* plugin =
new TRTBatchedNMSPluginDynamic(serialData, serialLength);
TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}

void TRTBatchedNMSPluginDynamicCreator::setPluginNamespace(
const char* libNamespace) {
mNamespace = libNamespace;
}

const char* TRTBatchedNMSPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}

REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSCreator);
} // namespace mmlab
53 changes: 10 additions & 43 deletions backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,22 @@
#include <string>
#include <vector>

#include "trt_plugin_helper.hpp"

class TRTBatchedNMSPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
#include "trt_plugin_base.hpp"
namespace mmlab {
class TRTBatchedNMS : public TRTPluginBase {
public:
TRTBatchedNMSPluginDynamic(nvinfer1::plugin::NMSParameters param);
TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param);

TRTBatchedNMSPluginDynamic(const void* data, size_t length);
TRTBatchedNMS(const std::string& name, const void* data, size_t length);

~TRTBatchedNMSPluginDynamic() override = default;
~TRTBatchedNMS() override = default;

int getNbOutputs() const override;

nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;

int initialize() override;

void terminate() override;

size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
Expand Down Expand Up @@ -52,67 +48,38 @@ class TRTBatchedNMSPluginDynamic : public nvinfer1::IPluginV2DynamicExt {

const char* getPluginVersion() const override;

void destroy() override;

nvinfer1::IPluginV2DynamicExt* clone() const override;

nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputType,
int nbInputs) const override;

void setPluginNamespace(const char* libNamespace) override;

const char* getPluginNamespace() const override;

void setClipParam(bool clip);

private:
nvinfer1::plugin::NMSParameters param{};
int boxesSize{};
int scoresSize{};
int numPriors{};
std::string mNamespace;
bool mClipBoxes{};

protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
};

class TRTBatchedNMSPluginDynamicCreator : public nvinfer1::IPluginCreator {
class TRTBatchedNMSCreator : public TRTPluginCreatorBase {
public:
TRTBatchedNMSPluginDynamicCreator();
TRTBatchedNMSCreator();

~TRTBatchedNMSPluginDynamicCreator() override = default;
~TRTBatchedNMSCreator() override = default;

const char* getPluginName() const override;

const char* getPluginVersion() const override;

const nvinfer1::PluginFieldCollection* getFieldNames() override;

nvinfer1::IPluginV2Ext* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;

nvinfer1::IPluginV2Ext* deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) override;

void setPluginNamespace(const char* libNamespace) override;

const char* getPluginNamespace() const override;

private:
nvinfer1::PluginFieldCollection mFC;
nvinfer1::plugin::NMSParameters params;
std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};

} // namespace mmlab
#endif // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// modify from
// https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
#include "cuda_runtime_api.h"
#include "kernel.h"
#include "trt_batched_nms_kernel.hpp"

pluginStatus_t nmsInference(
cudaStream_t stream, const int N, const int perBatchBoxesSize,
Expand Down
16 changes: 16 additions & 0 deletions backend_ops/tensorrt/batched_nms/trt_batched_nms_kernel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef TRT_BATCHED_NMS_KERNEL_HPP
#define TRT_BATCHED_NMS_KERNEL_HPP
#include "cuda_runtime_api.h"
#include "kernel.h"

pluginStatus_t nmsInference(
cudaStream_t stream, const int N, const int perBatchBoxesSize,
const int perBatchScoresSize, const bool shareLocation,
const int backgroundLabelId, const int numPredsPerClass,
const int numClasses, const int topK, const int keepTopK,
const float scoreThreshold, const float iouThreshold,
const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
const void* confData, void* nmsedDets, void* nmsedLabels, void* workspace,
bool isNormalized, bool confSigmoid, bool clipBoxes);

#endif
Loading

0 comments on commit 3fa94f4

Please sign in to comment.