Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] TensorRT Anchor generator plugin #646

Merged
merged 4 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
}
const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); }

  // Default shape/format configuration hook: no per-build state is cached.
  // Plugins that must inspect input/output descriptors override this.
  virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                               const nvinfer1::DynamicPluginTensorDesc *out,
                               int nbOutputs) TRT_NOEXCEPT override {}

  // Default workspace query: plugins that need no scratch memory report 0 bytes.
  virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                  const nvinfer1::PluginTensorDesc *outputs,
                                  int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  // Default no-op: this base does not retain the cudnn/cublas handles or the
  // GPU allocator that TensorRT offers here.
  virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
                               nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {}

virtual void detachFromContext() TRT_NOEXCEPT override {}

protected:
const std::string mLayerName;
std::string mNamespace;
Expand All @@ -34,10 +49,8 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
#endif
Expand Down
154 changes: 154 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_grid_priors.hpp"

#include <assert.h>

#include <chrono>

#include "trt_grid_priors_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"GridPriorsTRT"};
} // namespace

// Construct from layer name and stride; stride.d[0] = stride_w, stride.d[1] = stride_h
// (order fixed by GridPriorsTRTCreator::createPlugin).
GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride)
    : TRTPluginBase(name), mStride(stride) {}

// Deserialization constructor: restore mStride from the byte stream written by
// serialize().
// NOTE(review): `name` is taken by value here but by const& in the other
// constructor — consider unifying (requires touching the header declaration too).
GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {
  deserialize_value(&data, &length, &mStride);
}
GridPriorsTRT::~GridPriorsTRT() {}

// Duplicate this plugin (same stride and namespace) for another engine/context.
nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT {
  auto *copy = new GridPriorsTRT(mLayerName, mStride);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}

// Output shape is (feat_h * feat_w * num_base_anchors, 4).
//   inputs[0] == base_anchor   (num_base_anchors, 4)
//   inputs[1] == empty_h       (first dim carries feat_h)
//   inputs[2] == empty_w       (first dim carries feat_w)
nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  using nvinfer1::DimensionOperation;
  // Number of grid cells: feat_w * feat_h.
  const auto *num_cells =
      exprBuilder.operation(DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]);
  nvinfer1::DimsExprs out;
  out.nbDims = 2;
  out.d[0] = exprBuilder.operation(DimensionOperation::kPROD, *num_cells, *(inputs[0].d[0]));
  out.d[1] = exprBuilder.constant(4);
  return out;
}

// Accept linear fp32 for the anchor input, require the output to mirror it,
// and accept anything for the two shape-carrying inputs.
bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
                                              int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  const auto &desc = ioDesc[pos];
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kFLOAT &&
           desc.format == nvinfer1::TensorFormat::kLINEAR;
  }
  if (pos == nbInputs) {
    // The single output uses the same type/format as the first input.
    return desc.type == ioDesc[0].type && desc.format == ioDesc[0].format;
  }
  return true;
}

// Launch the anchor-expansion kernel. Returns 0 on success, 1 for an
// unsupported input dtype. inputs[1]/inputs[2] are only read for their shapes.
int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                           const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                           void *const *outputs, void *workSpace,
                           cudaStream_t stream) TRT_NOEXCEPT {
  const int num_base_anchors = inputDesc[0].dims.d[0];
  const int feat_h = inputDesc[1].dims.d[0];
  const int feat_w = inputDesc[2].dims.d[0];

  // Only fp32 is supported (matches supportsFormatCombination).
  if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
    return 1;
  }

  trt_grid_priors_impl<float>((float *)inputs[0], (float *)outputs[0], num_base_anchors, feat_w,
                              feat_h, mStride.d[0], mStride.d[1], stream);
  return 0;
}

// Output dtype follows the first (base_anchor) input.
nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                                    int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); }

// Write the plugin configuration (mStride) into `buffer`; must emit exactly
// getSerializationSize() bytes. (Removed a stray empty statement `;`.)
void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT {
  serialize_value(&buffer, mStride);
}

////////////////////// creator /////////////////////////////

// Register the fields parsed by createPlugin(). Fields carry only names here;
// createPlugin() skips entries whose data pointer is null and falls back to
// stride defaults of 1.
GridPriorsTRTCreator::GridPriorsTRTCreator() {
  mPluginAttributes.clear();
  mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h"));
  mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w"));
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

// Build a GridPriorsTRT from the "stride_h"/"stride_w" int fields; both
// default to 1 when absent or when a field carries no data.
nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  int stride_w = 1;
  int stride_h = 1;

  for (int idx = 0; idx < fc->nbFields; ++idx) {
    const auto &field = fc->fields[idx];
    if (field.data == nullptr) {
      continue;
    }
    const std::string field_name(field.name);
    if (field_name == "stride_w") {
      stride_w = static_cast<const int *>(field.data)[0];
    } else if (field_name == "stride_h") {
      stride_h = static_cast<const int *>(field.data)[0];
    }
  }
  // Order matters: enqueue() reads d[0] as stride_w and d[1] as stride_h.
  const nvinfer1::Dims stride{2, {stride_w, stride_h}};

  auto *plugin = new GridPriorsTRT(name, stride);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

// Rebuild a plugin instance from the byte stream written by
// GridPriorsTRT::serialize().
nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name,
                                                             const void *serialData,
                                                             size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new GridPriorsTRT(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator);
} // namespace mmdeploy
66 changes: 66 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_PRIORS_HPP
#define TRT_GRID_PRIORS_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class GridPriorsTRT : public TRTPluginBase {
public:
GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride);

GridPriorsTRT(const std::string name, const void *data, size_t length);

GridPriorsTRT() = delete;

~GridPriorsTRT() TRT_NOEXCEPT override;

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;

// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;

private:
nvinfer1::Dims mStride;

cublasHandle_t m_cublas_handle;
};

// Factory registered with TensorRT (REGISTER_TENSORRT_PLUGIN in the .cpp) that
// builds GridPriorsTRT instances from plugin fields or serialized engine data.
class GridPriorsTRTCreator : public TRTPluginCreatorBase {
 public:
  GridPriorsTRTCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  // Build from "stride_h"/"stride_w" int fields (each defaults to 1).
  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  // Rebuild from a byte stream produced by GridPriorsTRT::serialize().
  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_GRID_PRIORS_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"
#include "trt_grid_priors_kernel.hpp"
#include "trt_plugin_helper.hpp"

// Expand each base anchor across the feature-map grid.
// Output row `index` decodes as:
//   anchor a = index % num_base_anchors
//   cell  x  = (index / num_base_anchors) % feat_w
//   cell  y  = index / (num_base_anchors * feat_w)
// and equals base_anchor[a] shifted by (x * stride_w, y * stride_h).
// NOTE(review): `extern __shared__` inside a template can clash if a second
// dtype is ever instantiated; only <float> is instantiated in this file.
template <typename scalar_t>
__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output,
                                       int num_base_anchors, int feat_w, int feat_h, int stride_w,
                                       int stride_h) {
  // load base anchor into shared memory.
  extern __shared__ scalar_t shared_base_anchor[];
  for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) {
    shared_base_anchor[i] = base_anchor[i];
  }
  // All of shared_base_anchor must be populated before any thread reads it.
  __syncthreads();

  CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) {
    // `<< 2` == `* 4`: four coordinates (x1, y1, x2, y2) per anchor.
    const int a_offset = (index % num_base_anchors) << 2;
    const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w);
    const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h);

    auto out_start = output + index * 4;
    out_start[0] = shared_base_anchor[a_offset] + w;
    out_start[1] = shared_base_anchor[a_offset + 1] + h;
    out_start[2] = shared_base_anchor[a_offset + 2] + w;
    out_start[3] = shared_base_anchor[a_offset + 3] + h;
  }
}

// Launch the grid-priors kernel on `stream`. Dynamic shared memory holds the
// base anchors: num_base_anchors * 4 values, rounded up to a multiple of 32.
template <typename scalar_t>
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
                          int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) {
  const int num_outputs = num_base_anchors * feat_w * feat_h;
  const size_t shared_bytes = DIVUP(num_base_anchors * 4, 32) * 32 * sizeof(scalar_t);
  trt_grid_priors_kernel<<<GET_BLOCKS(num_outputs), THREADS_PER_BLOCK, shared_bytes, stream>>>(
      base_anchor, output, num_base_anchors, feat_w, feat_h, stride_w, stride_h);
}

// Only fp32 is instantiated; GridPriorsTRT::enqueue rejects other dtypes.
template void trt_grid_priors_impl<float>(const float* base_anchor, float* output,
                                          int num_base_anchors, int feat_w, int feat_h,
                                          int stride_w, int stride_h, cudaStream_t stream);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_GRID_PRIORS_KERNEL_HPP
#define TRT_GRID_PRIORS_KERNEL_HPP
#include <cuda_runtime.h>

// Expand `base_anchor` (num_base_anchors, 4) over a (feat_h, feat_w) grid with
// the given strides, writing num_base_anchors * feat_w * feat_h rows of 4
// values to `output` on `stream`.
template <typename scalar_t>
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
                          int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream);

#endif
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor import * # noqa: F401,F403
from .bbox import * # noqa: F401,F403
from .ops import * # noqa: F401,F403
from .post_processing import * # noqa: F401,F403
81 changes: 81 additions & 0 deletions mmdeploy/codebase/mmdet/core/anchor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.onnx import symbolic_helper

from mmdeploy.core import FUNCTION_REWRITER


class GridPriorsTRTOp(torch.autograd.Function):
    """Generate grid anchors for one feature level.

    ``forward`` computes the anchors in pure PyTorch; ``symbolic`` maps the op
    onto the custom TensorRT plugin ``mmdeploy::GridPriorsTRT`` during ONNX
    export. (Fix: removed GitHub review-widget text that had been pasted into
    the ``g.op`` call and broke the file.)
    """

    @staticmethod
    def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int,
                stride_w: int):
        """Shift ``base_anchors`` across the feature-map grid.

        Args:
            base_anchors (Tensor): base anchors, shape (num_anchors, 4).
            feat_h: feature-map height (int or 0-dim tensor).
            feat_w: feature-map width (int or 0-dim tensor).
            stride_h (int): anchor stride along y.
            stride_w (int): anchor stride along x.

        Returns:
            Tensor: all anchors, shape (feat_h * feat_w * num_anchors, 4).
        """
        device = base_anchors.device
        dtype = base_anchors.dtype
        shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
        shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h

        def _meshgrid(x, y, row_major=True):
            # use shape instead of len to keep tracing while exporting to onnx
            xx = x.repeat(y.shape[0])
            yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
            if row_major:
                return xx, yy
            else:
                return yy, xx

        shift_xx, shift_yy = _meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)

        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
        # then (0, 1), (0, 2), ...
        return all_anchors

    @staticmethod
    @symbolic_helper.parse_args('v', 'v', 'v', 'i', 'i')
    def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
                 stride_w: int):
        """ONNX symbolic: emit a ``mmdeploy::GridPriorsTRT`` node.

        zero_h and zero_w are zero tensors whose *shapes* carry feat_h/feat_w
        into the plugin, which only reads their first dimension.
        """
        feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
        feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
        zero_h = g.op(
            'ConstantOfShape',
            feat_h,
            value_t=torch.tensor([0], dtype=torch.long),
        )
        zero_w = g.op(
            'ConstantOfShape',
            feat_w,
            value_t=torch.tensor([0], dtype=torch.long),
        )
        return g.op(
            'mmdeploy::GridPriorsTRT',
            base_anchors,
            zero_h,
            zero_w,
            stride_h_i=stride_h,
            stride_w_i=stride_w)


grid_priors_trt = GridPriorsTRTOp.apply


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.core.anchor.anchor_generator.'
    'AnchorGenerator.single_level_grid_priors',
    backend='tensorrt')
def anchorgenerator__single_level_grid_priors__trt(ctx,
                                                   self,
                                                   featmap_size,
                                                   level_idx,
                                                   dtype=torch.float32,
                                                   device='cuda'):
    """Rewrite ``single_level_grid_priors`` for the TensorRT backend.

    Routes dynamic-shape exports through ``grid_priors_trt`` so the anchors are
    produced by the GridPriorsTRT plugin; static (all-int) shapes fall back to
    the original implementation. (Fix: removed GitHub review-widget text that
    had been pasted into the function body and broke the file.)

    Args:
        ctx: rewriter context providing ``origin_func``.
        self: the AnchorGenerator instance being rewritten.
        featmap_size (tuple): (feat_h, feat_w); entries may be traced tensors.
        level_idx (int): index of the feature level.
        dtype: anchor dtype, defaults to torch.float32.
        device: anchor device, defaults to 'cuda'.
    """
    feat_h, feat_w = featmap_size
    # Static shapes need no plugin; defer to the original implementation.
    if isinstance(feat_h, int) and isinstance(feat_w, int):
        return ctx.origin_func(self, featmap_size, level_idx, dtype, device)
    base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
    stride_w, stride_h = self.strides[level_idx]
    return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w)
Loading