Skip to content

Commit

Permalink
Trt elementwise plugin serialize (#31587)
Browse files Browse the repository at this point in the history
* add serialize unittest

* fix elementwise_op TensorRT plugin serialization bug
  • Loading branch information
shangzhizhou committed Mar 23, 2021
1 parent 23ebfa5 commit 4d5205d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,

// Nothing to allocate or pre-compute for this plugin; return 0 (success).
int ElementwisePluginDynamic::initialize() { return 0; }

size_t ElementwisePluginDynamic::getSerializationSize() const { return 0; }
// Byte size of the serialized plugin state: the elementwise op type
// (written as a C string via type_.c_str()) followed by the axis int.
// Must stay in sync with serialize() below and with the deserializing
// constructor in the header.
size_t ElementwisePluginDynamic::getSerializationSize() const {
  return SerializedSize(type_.c_str()) + SerializedSize(axis_);
}

void ElementwisePluginDynamic::serialize(void *buffer) const {}
// Writes the plugin state into `buffer`. Each SerializeValue call advances
// `buffer` past the bytes it wrote, so the write order here (type string,
// then axis) defines the wire format and must match the read order in the
// deserializing constructor and the sizes in getSerializationSize().
void ElementwisePluginDynamic::serialize(void *buffer) const {
  SerializeValue(&buffer, type_.c_str());
  SerializeValue(&buffer, axis_);
}

nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
Expand Down
47 changes: 46 additions & 1 deletion paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
public:
  // Construct directly from the elementwise op type string and the
  // broadcast axis (the values that serialize()/deserialize round-trip).
  explicit ElementwisePluginDynamic(const std::string& type, int axis)
      : type_(type), axis_(axis) {}
ElementwisePluginDynamic(void const* serialData, size_t serialLength) {}
ElementwisePluginDynamic(void const* serialData, size_t serialLength) {
const char* elementwise_type;
DeserializeValue(&serialData, &serialLength, &elementwise_type);
type_ = std::string(elementwise_type);
DeserializeValue(&serialData, &serialLength, &axis_);
}
  // TensorRT-required deep copy: a new plugin with the same op type and
  // axis. Ownership of the returned object passes to the caller (TensorRT).
  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new ElementwisePluginDynamic(type_, axis_);
  }
Expand Down Expand Up @@ -138,6 +143,46 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
  // Elementwise op type string and broadcast axis — the full plugin state
  // persisted by serialize() and restored by the deserializing constructor.
  std::string type_;
  int axis_;
};

// IPluginCreator used by the TensorRT runtime to re-create an
// ElementwisePluginDynamic from its serialized byte stream when a cached
// engine is loaded. Registered below via REGISTER_TRT_PLUGIN_V2.
class ElementwisePluginV2Creator : public nvinfer1::IPluginCreator {
 public:
  ElementwisePluginV2Creator() {}

  // Name/version pair TensorRT uses to locate this creator; the name is
  // presumably matched against the plugin's own type string — confirm.
  const char* getPluginName() const override { return "elementwise_plugin"; }

  const char* getPluginVersion() const override { return "1"; }

  // No creation-time fields are exposed: the collection is empty.
  const nvinfer1::PluginFieldCollection* getFieldNames() override {
    return &field_collection_;
  }

  // Building from a PluginFieldCollection is unsupported; plugins are only
  // rebuilt through deserializePlugin().
  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
    return nullptr;
  }

  // Reconstructs the plugin from serialized bytes; ownership of the new
  // object transfers to the TensorRT runtime.
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length) override {
    return new ElementwisePluginDynamic(serial_data, serial_length);
  }

  void setPluginNamespace(const char* lib_namespace) override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

REGISTER_TRT_PLUGIN_V2(ElementwisePluginV2Creator);
#endif

} // namespace plugin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,58 @@ def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_mul(x=data1, y=data2)


class TensorRTSubgraphPassElementwiseSerializeTest(
        TensorRTSubgraphPassElementwiseTest):
    """Re-runs the elementwise subgraph test with engine serialization
    enabled, clearing any stale optimization cache beforehand."""

    def setUp(self):
        super(TensorRTSubgraphPassElementwiseSerializeTest, self).setUp()
        # Same TRT config as the parent except the fifth flag is True
        # (presumably use_static, enabling engine serialization — confirm
        # against TensorRTParam's field order).
        self.trt_parameters = TensorRTSubgraphPassElementwiseTest.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False)

    def test_check_output(self):
        # Remove a previously serialized engine so this run rebuilds it.
        cache_dir = self.path + "_opt_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
        super(TensorRTSubgraphPassElementwiseSerializeTest,
              self).test_check_output()


class TensorRTSubgraphPassElementwiseBroadcastDynamicTest(InferencePassTest):
    """Elementwise add with rank-broadcast inputs under TRT dynamic-shape
    mode, with engine serialization enabled."""

    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data1 = fluid.data(
                name="data1", shape=[-1, 3, 64, 64], dtype="float32")
            data2 = fluid.data(name="data2", shape=[64, 64], dtype="float32")
            eltwise_out = self.append_eltwise(data1, data2)
            out = fluid.layers.batch_norm(eltwise_out, is_test=True)
            self.feeds = {
                "data1": np.random.random([1, 3, 64, 64]).astype("float32"),
                "data2": np.random.random([64, 64]).astype("float32"),
            }
            self.enable_trt = True
            self.trt_parameters = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.TensorRTParam(
                1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False)
            # min / max / opt profiles for the dynamic (third) dimension.
            min_shape = {'data1': [1, 3, 8, 64], 'data2': [8, 64]}
            max_shape = {'data1': [1, 3, 512, 64], 'data2': [512, 64]}
            opt_shape = {'data1': [1, 3, 256, 64], 'data2': [256, 64]}
            self.dynamic_shape_params = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.DynamicShapeParam(
                min_shape, max_shape, opt_shape, False)
            self.fetch_list = [out]

    def append_eltwise(self, data1, data2):
        return fluid.layers.elementwise_add(x=data1, y=data2)

    def test_check_output(self):
        # Drop any previously serialized engine before the run.
        cache_dir = self.path + "_opt_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
        if core.is_compiled_with_cuda():
            self.check_output_with_option(True)
            self.assertTrue(
                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


class TensorRTSubgraphPassShuffleChannelTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
Expand Down

0 comments on commit 4d5205d

Please sign in to comment.