From 4d42365ac229d2959dbb37a35b9c547af1efe617 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 25 May 2021 20:03:57 +0800 Subject: [PATCH] [Feature]: add TensorRT InstanceNormalization plugin (#1034) * add instancenorm plugin * resolve comments * fix lint * fix typo --- docs/tensorrt_custom_ops.md | 42 +++ docs/tensorrt_plugin.md | 20 +- .../tensorrt/plugins/trt_instance_norm.cpp | 245 ++++++++++++++++++ mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp | 2 + mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp | 120 +++++++++ mmcv/ops/csrc/tensorrt/trt_serialize.hpp | 18 +- mmcv/tensorrt/tensorrt_utils.py | 4 +- tests/test_ops/test_tensorrt.py | 96 ++++++- 8 files changed, 513 insertions(+), 34 deletions(-) create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp create mode 100644 mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp diff --git a/docs/tensorrt_custom_ops.md b/docs/tensorrt_custom_ops.md index 7bf369cfb7..0b5b1b83a7 100644 --- a/docs/tensorrt_custom_ops.md +++ b/docs/tensorrt_custom_ops.md @@ -45,6 +45,12 @@ - [Inputs](#inputs-6) - [Outputs](#outputs-6) - [Type Constraints](#type-constraints-6) + - [MMCVInstanceNormalization](#mmcvinstancenormalization) + - [Description](#description-7) + - [Parameters](#parameters-7) + - [Inputs](#inputs-7) + - [Outputs](#outputs-7) + - [Type Constraints](#type-constraints-7) @@ -303,3 +309,39 @@ Returns a namedtuple (`values`, `indices`) where `values` is the cumulative mini ### Type Constraints - T:tensor(float32, Linear) + +## MMCVInstanceNormalization + +### Description + +Carries out instance normalization as described in the paper https://arxiv.org/abs/1607.08022. + +y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance are computed per instance per channel. + +### Parameters + +| Type | Parameter | Description | +| ------- | --------- | -------------------------------------------------------------------- | +| `float` | `epsilon` | The epsilon value to use to avoid division by zero. 
Default is 1e-05 | + +### Inputs + +<dl>
+<dt>input: T</dt>
+<dd>Input data tensor from the previous operator; dimensions for image case are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and the width of the data. For non image case, the dimensions are in the form of (N x C x D1 x D2 ... Dn), where N is the batch size.</dd>
+<dt>scale: T</dt>
+<dd>The input 1-dimensional scale tensor of size C.</dd>
+<dt>B: T</dt>
+<dd>The input 1-dimensional bias tensor of size C.</dd>
+</dl>
+ +### Outputs + +<dl>
+<dt>output: T</dt>
+<dd>The output tensor of the same shape as input.</dd>
+</dl>
+ +### Type Constraints + +- T:tensor(float32, Linear) diff --git a/docs/tensorrt_plugin.md b/docs/tensorrt_plugin.md index 024896da83..e8669943c4 100644 --- a/docs/tensorrt_plugin.md +++ b/docs/tensorrt_plugin.md @@ -24,16 +24,16 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u ## List of TensorRT plugins supported in MMCV -| ONNX Operator | TensorRT Plugin | MMCV Releases | -| :---------------: | :-------------------------------------------------------------: | :-----------: | -| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 | -| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 | -| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 | -| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 | -| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 | -| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master | -| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master | - +| ONNX Operator | TensorRT Plugin | MMCV Releases | +| :-----------------------: | :-----------------------------------------------------------------------------: | :-----------: | +| MMCVRoiAlign | [MMCVRoiAlign](./tensorrt_custom_ops.md#mmcvroialign) | 1.2.6 | +| ScatterND | [ScatterND](./tensorrt_custom_ops.md#scatternd) | 1.2.6 | +| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 | +| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 | +| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 | +| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master | +| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master | +| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master | Notes - All plugins listed above are developed on 
TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0 diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp new file mode 100644 index 0000000000..1efdcb3a8d --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_instance_norm.cpp @@ -0,0 +1,245 @@ +// Modified from: +// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.cpp + +#include "trt_instance_norm.hpp" + +#include + +#include + +#include "trt_serialize.hpp" + +using namespace nvinfer1; + +cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, + cudnnDataType_t* cudnn_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + *cudnn_dtype = CUDNN_DATA_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *cudnn_dtype = CUDNN_DATA_HALF; + break; + default: + return CUDNN_STATUS_BAD_PARAM; + } + return CUDNN_STATUS_SUCCESS; +} + +namespace { +constexpr const char* PLUGIN_VERSION{"1"}; +constexpr const char* PLUGIN_NAME{"MMCVInstanceNormalization"}; +} // namespace + +PluginFieldCollection InstanceNormalizationDynamicCreator::mFC{}; +std::vector InstanceNormalizationDynamicCreator::mPluginAttributes; + +InstanceNormalizationDynamic::InstanceNormalizationDynamic( + const std::string& name, float epsilon) + : mLayerName(name), mEpsilon(epsilon) {} + +InstanceNormalizationDynamic::InstanceNormalizationDynamic( + const std::string& name, void const* serialData, size_t serialLength) + : mLayerName(name) { + deserialize_value(&serialData, &serialLength, &mEpsilon); +} + +InstanceNormalizationDynamic::~InstanceNormalizationDynamic() {} + +// InstanceNormalizationDynamic returns one output. 
+int InstanceNormalizationDynamic::getNbOutputs() const { return 1; }
+
+// The single output has exactly the shape of the data input (input 0).
+DimsExprs InstanceNormalizationDynamic::getOutputDimensions(
+    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
+    nvinfer1::IExprBuilder& exprBuilder) {
+  nvinfer1::DimsExprs output(inputs[0]);
+  return output;
+}
+
+// Nothing to set up here: the cuDNN descriptors are created in
+// attachToContext().
+int InstanceNormalizationDynamic::initialize() { return 0; }
+
+void InstanceNormalizationDynamic::terminate() {}
+
+namespace {
+// Product of dims.d[first] ... dims.d[nbDims - 1]; 1 if that range is empty.
+inline int trailingVolume(const nvinfer1::Dims& dims, int first) {
+  int volume = 1;
+  for (int i = first; i < dims.nbDims; ++i) {
+    volume *= dims.d[i];
+  }
+  return volume;
+}
+
+// Aligned number of bytes needed to hold n * c scale (or bias) values of the
+// given type. The arithmetic is done in size_t so that large n * c cannot
+// overflow int.
+inline size_t replicatedParamBytes(int n, int c, nvinfer1::DataType type) {
+  return mmcv::getAlignedSize(static_cast<size_t>(n) * static_cast<size_t>(c) *
+                              mmcv::getElementSize(type));
+}
+}  // namespace
+
+// Workspace layout: [scale replicated n times | bias replicated n times],
+// each half aligned by mmcv::getAlignedSize.
+size_t InstanceNormalizationDynamic::getWorkspaceSize(
+    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
+    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
+  const nvinfer1::Dims& dims = inputs[0].dims;
+  return replicatedParamBytes(dims.d[0], dims.d[1], inputs[1].type) * 2;
+}
+
+// Runs instance normalization as a cuDNN spatial batch normalization over a
+// tensor viewed as (1, n * c, h, w): every (instance, channel) pair becomes
+// its own "channel", so the statistics are computed per instance per channel.
+//
+// inputs[0]: data (N, C, D1, ..., Dk), inputs[1]: scale (C), inputs[2]: bias
+// (C). Returns 0 on success and 1 if the tensors are unsupported or any
+// CUDA/cuDNN call fails.
+int InstanceNormalizationDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc* inputDesc,
+    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
+    void* const* outputs, void* workspace, cudaStream_t stream) {
+  const nvinfer1::Dims& input_dims = inputDesc[0].dims;
+  if (input_dims.nbDims < 2) {
+    return 1;
+  }
+
+  // _b_desc below is always CUDNN_DATA_FLOAT, so scale and bias must really be
+  // float, otherwise cuDNN would misread them and run past the workspace.
+  if (inputDesc[1].type != nvinfer1::DataType::kFLOAT ||
+      inputDesc[2].type != nvinfer1::DataType::kFLOAT) {
+    return 1;
+  }
+
+  cudnnDataType_t cudnn_dtype{};
+  if (convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype) !=
+      CUDNN_STATUS_SUCCESS) {
+    return 1;
+  }
+
+  const int n = input_dims.d[0];
+  const int c = input_dims.d[1];
+  // The data is linear (contiguous), so all spatial dimensions after the first
+  // one can be folded into w. This keeps 3-D and 4-D behavior unchanged and
+  // makes (N, C, D1, ..., Dk) inputs normalize over every spatial element.
+  const int h = input_dims.nbDims > 2 ? input_dims.d[2] : 1;
+  const int w = trailingVolume(input_dims, 3);
+  if (n <= 0 || c <= 0 || h <= 0 || w <= 0) {
+    return 0;  // empty tensor: nothing to normalize
+  }
+
+  const size_t channel_bytes =
+      static_cast<size_t>(c) * mmcv::getElementSize(inputDesc[1].type);
+  // char* keeps the byte arithmetic standard C++ (void* arithmetic is a GNU
+  // extension).
+  char* n_scales = static_cast<char*>(workspace);
+  char* n_bias = n_scales + replicatedParamBytes(n, c, inputDesc[1].type);
+
+  // Replicate the per-channel scale/bias once per instance so that they line
+  // up with the n * c "channels" of the reshaped tensor.
+  for (int i = 0; i < n; ++i) {
+    const size_t offset = static_cast<size_t>(i) * channel_bytes;
+    if (cudaMemcpyAsync(n_scales + offset, inputs[1], channel_bytes,
+                        cudaMemcpyDeviceToDevice, stream) != cudaSuccess ||
+        cudaMemcpyAsync(n_bias + offset, inputs[2], channel_bytes,
+                        cudaMemcpyDeviceToDevice, stream) != cudaSuccess) {
+      return 1;
+    }
+  }
+
+  const auto failed = [](cudnnStatus_t status) {
+    return status != CUDNN_STATUS_SUCCESS;
+  };
+  if (failed(cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW,
+                                        CUDNN_DATA_FLOAT, 1, n * c, 1, 1)) ||
+      failed(cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW,
+                                        cudnn_dtype, 1, n * c, h, w)) ||
+      failed(cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW,
+                                        cudnn_dtype, 1, n * c, h, w)) ||
+      failed(cudnnSetStream(_cudnn_handle, stream))) {
+    return 1;
+  }
+
+  const float alpha = 1;
+  const float beta = 0;
+  // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
+  // overflows (NaNs) for fp32 data in some circumstances. The lower-
+  // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
+  // acceptable.
+  const cudnnStatus_t status = cudnnBatchNormalizationForwardTraining(
+      _cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, _x_desc,
+      inputs[0], _y_desc, outputs[0], _b_desc, n_scales, n_bias, 1., nullptr,
+      nullptr, mEpsilon, nullptr, nullptr);
+  return failed(status) ? 1 : 0;
+}
+
+// Only epsilon is serialized; see the deserializing constructor.
+size_t InstanceNormalizationDynamic::getSerializationSize() const {
+  return serialized_size(mEpsilon);
+}
+
+void InstanceNormalizationDynamic::serialize(void* buffer) const {
+  serialize_value(&buffer, mEpsilon);
+}
+
+// Tensor order in inOut: 0 = input, 1 = scale, 2 = bias, 3 = output.
+// All tensors must be linear. Input and output share one type (float or
+// half); scale and bias must be float because enqueue() hands them to cuDNN
+// through a CUDNN_DATA_FLOAT descriptor.
+bool InstanceNormalizationDynamic::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
+    int nbOutputs) {
+  const nvinfer1::PluginTensorDesc& desc = inOut[pos];
+  if (desc.format != nvinfer1::PluginFormat::kLINEAR) {
+    return false;
+  }
+  if (pos == 1 || pos == 2) {
+    return desc.type == nvinfer1::DataType::kFLOAT;
+  }
+  return (desc.type == nvinfer1::DataType::kFLOAT ||
+          desc.type == nvinfer1::DataType::kHALF) &&
+         desc.type == inOut[0].type;
+}
+
+const char* InstanceNormalizationDynamic::getPluginType() const {
+  return PLUGIN_NAME;
+}
+
+const char* InstanceNormalizationDynamic::getPluginVersion() const {
+  return PLUGIN_VERSION;
+}
+
+void InstanceNormalizationDynamic::destroy() { delete this; }
+
+// The copy carries the layer name, epsilon and namespace; cuDNN state is not
+// copied and is set up again by attachToContext() on the clone.
+IPluginV2DynamicExt* InstanceNormalizationDynamic::clone() const {
+  auto* plugin = new InstanceNormalizationDynamic{mLayerName, mEpsilon};
+  plugin->setPluginNamespace(mPluginNamespace.c_str());
+  return plugin;
+}
+
+// Set plugin namespace
+void InstanceNormalizationDynamic::setPluginNamespace(
+    const char* pluginNamespace) {
+  mPluginNamespace = pluginNamespace;
+}
+
+const char* InstanceNormalizationDynamic::getPluginNamespace() const {
+  return mPluginNamespace.c_str();
+}
+
+// The output has the same data type as the data input (input 0).
+nvinfer1::DataType InstanceNormalizationDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
+  return inputTypes[0];
+}
+
+// Attach the plugin object to an execution context and grant the plugin the
+// access to some context resource.
+void InstanceNormalizationDynamic::attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, + IGpuAllocator* gpuAllocator) { + _cudnn_handle = cudnnContext; + cudnnCreateTensorDescriptor(&_b_desc); + cudnnCreateTensorDescriptor(&_x_desc); + cudnnCreateTensorDescriptor(&_y_desc); +} + +// Detach the plugin object from its execution context. +void InstanceNormalizationDynamic::detachFromContext() { + cudnnDestroyTensorDescriptor(_y_desc); + cudnnDestroyTensorDescriptor(_x_desc); + cudnnDestroyTensorDescriptor(_b_desc); +} + +void InstanceNormalizationDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {} + +// InstanceNormalizationDynamicCreator methods +InstanceNormalizationDynamicCreator::InstanceNormalizationDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back( + PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char* InstanceNormalizationDynamicCreator::getPluginName() const { + return PLUGIN_NAME; +} + +const char* InstanceNormalizationDynamicCreator::getPluginVersion() const { + return PLUGIN_VERSION; +} + +const PluginFieldCollection* +InstanceNormalizationDynamicCreator::getFieldNames() { + return &mFC; +} + +IPluginV2DynamicExt* InstanceNormalizationDynamicCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) { + float epsilon = 1e-5; + const PluginField* fields = fc->fields; + for (int i = 0; i < fc->nbFields; ++i) { + const char* attrName = fields[i].name; + if (!strcmp(attrName, "epsilon")) { + epsilon = *(static_cast(fields[i].data)); + } + } + + InstanceNormalizationDynamic* obj = + new InstanceNormalizationDynamic(name, epsilon); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; +} + +IPluginV2DynamicExt* 
InstanceNormalizationDynamicCreator::deserializePlugin( + const char* name, const void* serialData, size_t serialLength) { + InstanceNormalizationDynamic* obj = + new InstanceNormalizationDynamic{name, serialData, serialLength}; + obj->setPluginNamespace(mNamespace.c_str()); + return obj; +} + +void InstanceNormalizationDynamicCreator::setPluginNamespace( + const char* libNamespace) { + mNamespace = libNamespace; +} + +const char* InstanceNormalizationDynamicCreator::getPluginNamespace() const { + return mNamespace.c_str(); +} diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp index ab4ee11e81..81f724f162 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp @@ -3,6 +3,7 @@ #include "trt_cummaxmin.hpp" #include "trt_deform_conv.hpp" #include "trt_grid_sampler.hpp" +#include "trt_instance_norm.hpp" #include "trt_nms.hpp" #include "trt_roi_align.hpp" #include "trt_scatternd.hpp" @@ -14,6 +15,7 @@ REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator); REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); +REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator); extern "C" { bool initLibMMCVInferPlugins() { return true; } diff --git a/mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp b/mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp new file mode 100644 index 0000000000..78060c3901 --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/trt_instance_norm.hpp @@ -0,0 +1,120 @@ +// Modified from: +// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.h + +#ifndef TRT_INSTANCE_NORMALIZATION_PLUGIN_H +#define TRT_INSTANCE_NORMALIZATION_PLUGIN_H +#include + +#include +#include +#include + +#include "trt_plugin_helper.hpp" + +typedef unsigned short half_type; + +class InstanceNormalizationDynamic 
final + : public nvinfer1::IPluginV2DynamicExt { + public: + InstanceNormalizationDynamic(const std::string& name, float epsilon); + + InstanceNormalizationDynamic(const std::string& name, void const* serialData, + size_t serialLength); + + InstanceNormalizationDynamic() = delete; + + ~InstanceNormalizationDynamic() override; + + int getNbOutputs() const override; + + // DynamicExt plugins returns DimsExprs class instead of Dims + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) override; + + int initialize() override; + + void terminate() override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + + size_t getSerializationSize() const override; + + void serialize(void* buffer) const override; + + // DynamicExt plugin supportsFormat update. 
+ bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + + const char* getPluginType() const override; + + const char* getPluginVersion() const override; + + void destroy() override; + + nvinfer1::IPluginV2DynamicExt* clone() const override; + + void setPluginNamespace(const char* pluginNamespace) override; + + const char* getPluginNamespace() const override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + void attachToContext(cudnnContext* cudnn, cublasContext* cublas, + nvinfer1::IGpuAllocator* allocator) override; + + void detachFromContext() override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override; + + private: + const std::string mLayerName; + float mEpsilon{}; + cudnnHandle_t _cudnn_handle{}; + cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{}; + std::string mPluginNamespace{}; +}; + +class InstanceNormalizationDynamicCreator : public nvinfer1::IPluginCreator { + public: + InstanceNormalizationDynamicCreator(); + + ~InstanceNormalizationDynamicCreator() override = default; + + const char* getPluginName() const override; + + const char* getPluginVersion() const override; + + const nvinfer1::PluginFieldCollection* getFieldNames() override; + + nvinfer1::IPluginV2DynamicExt* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override; + + nvinfer1::IPluginV2DynamicExt* deserializePlugin( + const char* name, const void* serialData, size_t serialLength) override; + + void setPluginNamespace(const char* pluginNamespace) override; + + const char* getPluginNamespace() const override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +#endif // TRT_INSTANCE_NORMALIZATION_PLUGIN_H diff --git 
a/mmcv/ops/csrc/tensorrt/trt_serialize.hpp b/mmcv/ops/csrc/tensorrt/trt_serialize.hpp index c9e75cbbe7..1f0899fdfe 100644 --- a/mmcv/ops/csrc/tensorrt/trt_serialize.hpp +++ b/mmcv/ops/csrc/tensorrt/trt_serialize.hpp @@ -1,18 +1,6 @@ -/* - * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Modified from: +// https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp + #ifndef TRT_SERIALIZE_HPP #define TRT_SERIALIZE_HPP #include diff --git a/mmcv/tensorrt/tensorrt_utils.py b/mmcv/tensorrt/tensorrt_utils.py index cf3785e986..b2a22ab3c0 100644 --- a/mmcv/tensorrt/tensorrt_utils.py +++ b/mmcv/tensorrt/tensorrt_utils.py @@ -91,7 +91,9 @@ def parse_data(name, typ): node_dict[output] = new_node nodes.insert(idx, new_node) nodes.remove(node) - + elif node.op_type == 'InstanceNormalization': + # directly change op name + node.op_type = 'MMCVInstanceNormalization' return onnx_model diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py index ddfa68165a..4726630858 100644 --- a/tests/test_ops/test_tensorrt.py +++ b/tests/test_ops/test_tensorrt.py @@ -9,7 +9,7 @@ import torch.nn as nn try: - from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, + from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine) except ImportError: pytest.skip( @@ -95,7 +95,7 @@ def test_roialign(): fp16_mode=fp16_mode, 
max_workspace_size=max_workspace_size) save_trt_engine(trt_engine, trt_file) - trt_model = TRTWraper(trt_file, ['input', 'rois'], ['roi_feat']) + trt_model = TRTWrapper(trt_file, ['input', 'rois'], ['roi_feat']) with torch.no_grad(): trt_outputs = trt_model({'input': input, 'rois': rois}) @@ -155,7 +155,7 @@ def test_nms(): fp16_mode=fp16_mode, max_workspace_size=max_workspace_size) save_trt_engine(trt_engine, trt_file) - trt_model = TRTWraper(trt_file, ['boxes', 'scores'], ['dets', 'inds']) + trt_model = TRTWrapper(trt_file, ['boxes', 'scores'], ['dets', 'inds']) with torch.no_grad(): trt_outputs = trt_model({'boxes': boxes, 'scores': scores}) @@ -237,7 +237,7 @@ def test_batched_nms(): fp16_mode=fp16_mode, max_workspace_size=max_workspace_size) save_trt_engine(trt_engine, trt_file) - trt_model = TRTWraper(trt_file, input_names, output_names) + trt_model = TRTWrapper(trt_file, input_names, output_names) with torch.no_grad(): trt_outputs = trt_model({ @@ -311,7 +311,7 @@ def func(data): max_workspace_size=max_workspace_size) save_trt_engine(trt_engine, trt_file) - trt_model = TRTWraper(trt_file, input_names, output_names) + trt_model = TRTWrapper(trt_file, input_names, output_names) with torch.no_grad(): trt_outputs = trt_model({'input': data.clone()}) @@ -387,7 +387,7 @@ def test_deform_conv(): max_workspace_size=max_workspace_size) save_trt_engine(trt_engine, trt_file) - trt_model = TRTWraper(trt_file, input_names, output_names) + trt_model = TRTWrapper(trt_file, input_names, output_names) with torch.no_grad(): trt_outputs = trt_model({'input': x.clone()}) @@ -463,7 +463,7 @@ def func(input, grid): max_workspace_size=max_workspace_size) save_trt_engine(trt_engine, trt_file) - trt_model = TRTWraper(trt_file, input_names, output_names) + trt_model = TRTWrapper(trt_file, input_names, output_names) with torch.no_grad(): trt_outputs = trt_model({'input': input.clone(), 'grid': grid.clone()}) @@ -555,7 +555,7 @@ def test_cummin_cummax(func: Callable): 
save_trt_engine(trt_engine, trt_file) # load and wrap TensorRT model - trt_model = TRTWraper(trt_file) + trt_model = TRTWrapper(trt_file) # remove trt model after loading if os.path.exists(trt_file): @@ -575,3 +575,83 @@ def test_cummin_cummax(func: Callable): torch.testing.assert_allclose(trt_output, pytorch_output) torch.testing.assert_allclose(trt_indices, pytorch_indices) + + +@pytest.mark.parametrize('dynamic_export', [True, False]) +@pytest.mark.parametrize('fp16_mode', [True, False]) +def test_instance_norm(dynamic_export, fp16_mode): + + n, c, h, w = 2, 3, 10, 10 + data = torch.randn(n, c, h, w).cuda() + norm = nn.InstanceNorm2d(c, affine=True) + + wrapped_model = WrapFunction(norm).eval().cuda() + + input_names = ['input'] + output_names = ['output'] + dynamic_axes = None + if dynamic_export: + dynamic_axes = { + 'input': { + 0: 'n', + 2: 'h', + 3: 'w', + }, + 'output': { + 0: 'n', + 2: 'h', + 3: 'w', + }, + } + with torch.no_grad(): + torch.onnx.export( + wrapped_model, (data.clone(), ), + onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=11) + + onnx_model = onnx.load(onnx_file) + + # create trt engine and wraper + if dynamic_export: + opt_shape_dict = { + 'input': + [list(data.shape), + list(data.shape), [2 * n, c, 2 * h, 2 * w]], + } + else: + opt_shape_dict = { + 'input': [list(data.shape), + list(data.shape), + list(data.shape)], + } + # trt config + max_workspace_size = 1 << 30 + + trt_engine = onnx2trt( + onnx_model, + opt_shape_dict, + fp16_mode=fp16_mode, + max_workspace_size=max_workspace_size) + + save_trt_engine(trt_engine, trt_file) + trt_model = TRTWrapper(trt_file, input_names, output_names) + + with torch.no_grad(): + trt_outputs = trt_model({'input': data.clone()}) + trt_results = trt_outputs['output'] + + # compute pytorch_output + with torch.no_grad(): + pytorch_results = wrapped_model(data.clone()) + + # allclose + if 
os.path.exists(onnx_file): + os.remove(onnx_file) + if os.path.exists(trt_file): + os.remove(trt_file) + assert torch.allclose(pytorch_results, trt_results)