forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add trt instance norm plugin (open-mmlab#16)
* add trt instance norm plugin * last line empty * fix clang format * fix grid_sample clang format * remove redundant * fix lint * refine codes * fix clang format * clang format * clang format * clang format
- Loading branch information
Showing
12 changed files
with
418 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
200 changes: 200 additions & 0 deletions
200
backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
// Modified from: | ||
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.cpp | ||
|
||
#include "trt_instance_norm.hpp"

#include <cuda_fp16.h>

#include <cstring>
#include <stdexcept>

#include "trt_serialize.hpp"
|
||
using namespace nvinfer1; | ||
|
||
namespace mmlab {
namespace {
// Identity reported to TensorRT's plugin registry. The name must match the
// op name used when exporting the network, and the version must match what
// the deserializing runtime looks up.
constexpr const char* PLUGIN_VERSION{"1"};
constexpr const char* PLUGIN_NAME{"TRTInstanceNormalization"};
}  // namespace
|
||
// Build-time constructor: records the epsilon used to stabilize the variance
// term of instance normalization.
TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name,
                                                  float epsilon)
    : TRTPluginBase(name) {
  mEpsilon = epsilon;
}
|
||
// Deserializing constructor: rebuilds the plugin from engine bytes produced
// by serialize(). Epsilon is the only persisted state.
TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name,
                                                  void const* serialData,
                                                  size_t serialLength)
    : TRTPluginBase(name) {
  // deserialize_value advances serialData/serialLength past the consumed bytes.
  deserialize_value(&serialData, &serialLength, &mEpsilon);
}
|
||
TRTInstanceNormalization::~TRTInstanceNormalization() {} | ||
|
||
// The plugin produces exactly one tensor: the normalized data.
int TRTInstanceNormalization::getNbOutputs() const {
  return 1;
}
|
||
// Instance normalization is shape-preserving: the output carries the data
// input's (symbolic) dimensions unchanged.
DimsExprs TRTInstanceNormalization::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) {
  return nvinfer1::DimsExprs(inputs[0]);
}
|
||
size_t TRTInstanceNormalization::getWorkspaceSize( | ||
const nvinfer1::PluginTensorDesc* inputs, int nbInputs, | ||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const { | ||
int n = inputs[0].dims.d[0]; | ||
int c = inputs[0].dims.d[1]; | ||
int elem_size = getElementSize(inputs[1].type); | ||
return getAlignedSize(n * c * elem_size) * 2; | ||
} | ||
|
||
int TRTInstanceNormalization::enqueue( | ||
const nvinfer1::PluginTensorDesc* inputDesc, | ||
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, | ||
void* const* outputs, void* workspace, cudaStream_t stream) { | ||
nvinfer1::Dims input_dims = inputDesc[0].dims; | ||
int n = input_dims.d[0]; | ||
int c = input_dims.d[1]; | ||
int h = input_dims.d[2]; | ||
int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1; | ||
int elem_size = getElementSize(inputDesc[1].type); | ||
|
||
void* n_scales = (void*)workspace; | ||
void* n_bias = (void*)(workspace + getAlignedSize(n * c * elem_size)); | ||
|
||
const void* scales = (const void*)inputs[1]; | ||
const void* bias = (const void*)inputs[2]; | ||
|
||
for (int i = 0; i < n; ++i) { | ||
cudaMemcpyAsync(n_scales + i * c * elem_size, scales, c * elem_size, | ||
cudaMemcpyDeviceToDevice, stream); | ||
cudaMemcpyAsync(n_bias + i * c * elem_size, bias, c * elem_size, | ||
cudaMemcpyDeviceToDevice, stream); | ||
} | ||
|
||
cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, | ||
n * c, 1, 1); | ||
cudnnDataType_t cudnn_dtype{}; | ||
convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype); | ||
cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, | ||
h, w); | ||
cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, | ||
h, w); | ||
float alpha = 1; | ||
float beta = 0; | ||
void const* x_ptr = inputs[0]; | ||
void* y_ptr = outputs[0]; | ||
cudnnSetStream(_cudnn_handle, stream); | ||
// Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical | ||
// overflows (NaNs) for fp32 data in some circumstances. The lower- | ||
// performance CUDNN_BATCHNORM_SPATIAL should be used if this is not | ||
// acceptable. | ||
cudnnBatchNormalizationForwardTraining( | ||
_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, _x_desc, | ||
x_ptr, _y_desc, y_ptr, _b_desc, n_scales, n_bias, 1., nullptr, nullptr, | ||
mEpsilon, nullptr, nullptr); | ||
return 0; | ||
} | ||
|
||
size_t TRTInstanceNormalization::getSerializationSize() const { | ||
return serialized_size(mEpsilon); | ||
} | ||
|
||
// Write epsilon into the engine blob; must mirror the deserializing ctor.
void TRTInstanceNormalization::serialize(void* buffer) const {
  serialize_value(&buffer, mEpsilon);
}
|
||
bool TRTInstanceNormalization::supportsFormatCombination( | ||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, | ||
int nbOutputs) { | ||
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT || | ||
inOut[pos].type == nvinfer1::DataType::kHALF) && | ||
inOut[pos].format == nvinfer1::PluginFormat::kLINEAR && | ||
inOut[pos].type == inOut[0].type); | ||
} | ||
|
||
const char* TRTInstanceNormalization::getPluginType() const { | ||
return PLUGIN_NAME; | ||
} | ||
|
||
const char* TRTInstanceNormalization::getPluginVersion() const { | ||
return PLUGIN_VERSION; | ||
} | ||
|
||
// Produce an independent copy (TensorRT clones plugins during build/execute),
// carrying over epsilon and the namespace.
IPluginV2DynamicExt* TRTInstanceNormalization::clone() const {
  auto* copy = new TRTInstanceNormalization{mLayerName, mEpsilon};
  // NOTE(review): this reads the class-local mPluginNamespace member; confirm
  // setPluginNamespace() stores into this same field rather than a base-class
  // member, otherwise the clone gets an empty namespace.
  copy->setPluginNamespace(mPluginNamespace.c_str());
  return copy;
}
|
||
// The output dtype mirrors the data input's dtype (fp32 or fp16, per
// supportsFormatCombination).
nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(
    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
  return inputTypes[0];
}
|
||
// Attach the plugin object to an execution context: borrow the context's
// cuDNN handle (owned by TensorRT) and create the tensor descriptors that
// enqueue() configures per call.
void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext,
                                               cublasContext* cublasContext,
                                               IGpuAllocator* gpuAllocator) {
  _cudnn_handle = cudnnContext;
  for (cudnnTensorDescriptor_t* desc : {&_b_desc, &_x_desc, &_y_desc}) {
    cudnnCreateTensorDescriptor(desc);
  }
}
|
||
// Detach the plugin object from its execution context. | ||
void TRTInstanceNormalization::detachFromContext() { | ||
cudnnDestroyTensorDescriptor(_y_desc); | ||
cudnnDestroyTensorDescriptor(_x_desc); | ||
cudnnDestroyTensorDescriptor(_b_desc); | ||
} | ||
|
||
// No configure-time state is cached: all shapes are read per call in
// enqueue(), so this override is intentionally empty.
void TRTInstanceNormalization::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
|
||
// TRTInstanceNormalizationCreator methods
// Declare the attribute schema the ONNX parser feeds to createPlugin():
// a single fp32 "epsilon" field.
TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() {
  mPluginAttributes.clear();
  mPluginAttributes.push_back(
      PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1));

  // Expose the attribute list through the field collection TensorRT queries.
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}
|
||
const char* TRTInstanceNormalizationCreator::getPluginName() const { | ||
return PLUGIN_NAME; | ||
} | ||
|
||
const char* TRTInstanceNormalizationCreator::getPluginVersion() const { | ||
return PLUGIN_VERSION; | ||
} | ||
|
||
// Build a plugin from parsed node attributes. Unknown fields are ignored;
// "epsilon" defaults to 1e-5 when absent.
IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) {
  // FIX: use a float literal (the original `1e-5` is a double literal that is
  // silently narrowed to float).
  float epsilon = 1e-5f;
  // Guard against a null/empty collection instead of dereferencing blindly.
  if (fc != nullptr && fc->fields != nullptr) {
    const PluginField* fields = fc->fields;
    for (int i = 0; i < fc->nbFields; ++i) {
      if (strcmp(fields[i].name, "epsilon") == 0) {
        epsilon = *(static_cast<const float*>(fields[i].data));
      }
    }
  }

  TRTInstanceNormalization* obj = new TRTInstanceNormalization(name, epsilon);
  obj->setPluginNamespace(mNamespace.c_str());
  return obj;
}
|
||
// Rebuild a plugin from its serialized engine representation.
IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin(
    const char* name, const void* serialData, size_t serialLength) {
  auto* plugin = new TRTInstanceNormalization{name, serialData, serialLength};
  plugin->setPluginNamespace(mNamespace.c_str());
  return plugin;
}
REGISTER_TENSORRT_PLUGIN(TRTInstanceNormalizationCreator); | ||
} // namespace mmlab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Modified from: | ||
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.h | ||
|
||
#ifndef TRT_INSTANCE_NORMALIZATION_HPP | ||
#define TRT_INSTANCE_NORMALIZATION_HPP | ||
#include <cudnn.h> | ||
|
||
#include <iostream> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "trt_plugin_base.hpp" | ||
|
||
typedef unsigned short half_type; | ||
|
||
namespace mmlab { | ||
// TensorRT dynamic-shape plugin implementing InstanceNormalization.
// Inputs: 0 = data (N, C, H[, W]); 1 = per-channel scale; 2 = per-channel
// bias. Output: one tensor with the same shape and dtype as input 0.
// Implemented on top of cuDNN batch normalization (see the .cpp).
class TRTInstanceNormalization final : public TRTPluginBase {
 public:
  // Build-time constructor; `epsilon` stabilizes the variance term.
  TRTInstanceNormalization(const std::string& name, float epsilon);

  // Deserializing constructor: rebuilds the plugin from engine bytes
  // produced by serialize().
  TRTInstanceNormalization(const std::string& name, void const* serialData,
                           size_t serialLength);

  // A default-constructed plugin would have no name/epsilon; forbid it.
  TRTInstanceNormalization() = delete;

  ~TRTInstanceNormalization() override;

  int getNbOutputs() const override;

  // DynamicExt plugins returns DimsExprs class instead of Dims
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) override;

  // Scratch bytes enqueue() needs: two aligned N*C regions for the
  // batch-replicated scale and bias.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;

  size_t getSerializationSize() const override;

  void serialize(void* buffer) const override;

  // DynamicExt plugin supportsFormat update.
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;

  const char* getPluginType() const override;

  const char* getPluginVersion() const override;

  nvinfer1::IPluginV2DynamicExt* clone() const override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const override;

  // Borrows the context's cuDNN handle and creates the tensor descriptors.
  void attachToContext(cudnnContext* cudnn, cublasContext* cublas,
                       nvinfer1::IGpuAllocator* allocator) override;

  // Destroys the descriptors created in attachToContext().
  void detachFromContext() override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) override;

 private:
  float mEpsilon{};               // variance-stabilizing constant
  cudnnHandle_t _cudnn_handle{};  // owned by TensorRT, never destroyed here
  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{};
  // NOTE(review): confirm setPluginNamespace() (inherited) actually writes
  // this member and not a base-class field — clone() reads it.
  std::string mPluginNamespace{};
};
|
||
// Plugin factory registered with TensorRT's plugin registry. Builds
// TRTInstanceNormalization instances either from ONNX-parsed attributes
// (createPlugin) or from serialized engine bytes (deserializePlugin).
class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase {
 public:
  TRTInstanceNormalizationCreator();

  ~TRTInstanceNormalizationCreator() override = default;

  const char* getPluginName() const override;

  const char* getPluginVersion() const override;

  // Reads the fp32 "epsilon" field; see the .cpp for the default used when
  // the field is absent.
  nvinfer1::IPluginV2DynamicExt* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override;

  nvinfer1::IPluginV2DynamicExt* deserializePlugin(
      const char* name, const void* serialData, size_t serialLength) override;
};
} // namespace mmlab | ||
#endif // TRT_INSTANCE_NORMALIZATION_HPP |
Oops, something went wrong.