diff --git a/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuse_pass.cc b/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuse_pass.cc index 30e511acad6..b9253ec9ef2 100644 --- a/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuse_pass.cc @@ -45,7 +45,7 @@ void ConvElementwiseTreeFusePass::Apply( << " elementwise_type: " << elementwise_type; fusion::ConvElementwiseTreeFuser fuser( conv_type, conv_has_bias, conv_has_prelu_alpha, elementwise_type); - fuser(graph.get()); + fuser.apply_impl(graph.get()); } } } diff --git a/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.cc b/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.cc index e2e72249ef3..9285cae5e26 100644 --- a/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.cc +++ b/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.cc @@ -26,6 +26,7 @@ void ConvElementwiseTreeFuser::BuildPattern() { auto* conv_input = VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput(); auto* conv_filter = VarNode("conv_filter") + ->assert_is_persistable_var() ->assert_is_op_input(conv_type_, "Filter") ->AsInput(); auto* elementwise_input = VarNode("elementwise_input") @@ -33,9 +34,10 @@ void ConvElementwiseTreeFuser::BuildPattern() { ->AsInput(); // create intermediate nodes - conv_output_ = VarNode("conv_output") - ->assert_is_op_output(conv_type_, "Output") - ->assert_is_op_input(elementwise_type_, "Y"); + auto* conv_output = VarNode("conv_output") + ->assert_is_op_output(conv_type_, "Output") + ->assert_is_op_input(elementwise_type_, "Y") + ->assert_only_one_output(); // create op nodes // The pass will not been applied if conv1x1 has already applied this pass. @@ -63,12 +65,12 @@ void ConvElementwiseTreeFuser::BuildPattern() { ((!has_act_type) || (has_act_type && act_type == "relu")); }; - conv_ = OpNode("conv", conv_type_) - ->assert_is_op(conv_type_) - ->assert_node_satisfied(conv_teller); - elementwise_ = OpNode("elementwise", elementwise_type_) - ->assert_is_op(elementwise_type_) - ->assert_node_satisfied(elementwise_teller); + auto* conv = OpNode("conv", conv_type_) + ->assert_is_op(conv_type_) + ->assert_node_satisfied(conv_teller); + auto* elementwise = OpNode("elementwise", elementwise_type_) + ->assert_is_op(elementwise_type_) + ->assert_node_satisfied(elementwise_teller); // create output node auto* elementwise_output = VarNode("elementwise_output") @@ -79,18 +81,20 @@ void ConvElementwiseTreeFuser::BuildPattern() { // consider two special cases: conv with bias, conv with prelu alpha std::vector conv_inputs{conv_input, conv_filter}; if (conv_has_bias_) { - auto* conv_bias = - VarNode("conv_bias")->assert_is_op_input(conv_type_, "Bias"); + auto* conv_bias = VarNode("conv_bias") + ->assert_is_op_input(conv_type_, "Bias") + ->assert_is_persistable_var(); conv_inputs.push_back(conv_bias); } if (conv_has_prelu_alpha_) { auto* conv_alpha = VarNode("conv_alpha") ->assert_is_op_input(conv_type_, "Prelu_alpha") + ->assert_is_persistable_var() ->AsInput(); conv_inputs.push_back(conv_alpha); } - conv_->LinksFrom(conv_inputs).LinksTo({conv_output_}); - elementwise_->LinksFrom({elementwise_input, conv_output_}) + conv->LinksFrom(conv_inputs).LinksTo({conv_output}); + elementwise->LinksFrom({elementwise_input, conv_output}) .LinksTo({elementwise_output}); } @@ -145,10 +149,10 @@ void ConvElementwiseTreeFuser::InsertNewNode(SSAGraph* graph, return; } - // NOTE: Mark these node as intermediate at this place. - conv_output_->AsIntermediate(); - conv_->AsIntermediate(); - elementwise_->AsIntermediate(); + // NOTE: push these note to nodes2rm_. + nodes2rm_.insert(matched.at("conv")); + nodes2rm_.insert(matched.at("conv_output")); + nodes2rm_.insert(matched.at("elementwise")); auto op_desc = GenOpDesc(matched); auto conv_op_new = LiteOpRegistry::Global().Create(conv_type_); diff --git a/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.h b/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.h index 776e346c874..5fa305612d9 100644 --- a/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.h +++ b/lite/core/optimizer/mir/fusion/conv_elementwise_tree_fuser.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "lite/core/optimizer/mir/pattern_matcher_high_api.h" @@ -34,6 +35,17 @@ class ConvElementwiseTreeFuser : public FuseBase { conv_has_prelu_alpha_ = conv_has_prelu_alpha; elementwise_type_ = elementwise_type; } + size_t apply_impl(SSAGraph* graph) { + BuildPattern(); + PerformPatternMatcher(graph); + + for (const auto& matched : key2nodes_) { + InsertNewNode(graph, matched); + } + + GraphSafeRemoveNodes(graph, nodes2rm_); + return key2nodes_.size(); + } void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; @@ -45,9 +57,7 @@ class ConvElementwiseTreeFuser : public FuseBase { bool conv_has_bias_{false}; bool conv_has_prelu_alpha_{false}; std::string elementwise_type_{""}; - PMNode* conv_output_; - PMNode* conv_; - PMNode* elementwise_; + std::set nodes2rm_; }; } // namespace fusion