Skip to content

Commit

Permalink
[core] fix swish op bug in conv_activation_fuser,conv_scale_fuser and…
Browse files Browse the repository at this point in the history
… scale_activation_fuser. (#10516)
  • Loading branch information
ddchenhao66 authored May 17, 2024
1 parent c18e583 commit 0741366
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
5 changes: 4 additions & 1 deletion lite/core/optimizer/mir/fusion/conv_activation_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
} else if (act_type_ == "tanh") {
op_desc.SetAttr("fuse_tanh", true);
} else if (act_type_ == "swish") {
float scale = act_op_desc.GetAttr<float>("beta");
float scale = 1.0f;
if (act_op_desc.HasAttr("beta")) {
scale = act_op_desc.GetAttr<float>("beta");
}
op_desc.SetAttr("swish_scale", scale);
op_desc.SetAttr("fuse_swish", true);
} else if (act_type_ == "abs") {
Expand Down
5 changes: 4 additions & 1 deletion lite/core/optimizer/mir/fusion/conv_scale_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ void ConvScaleFuser::InsertNewNode(SSAGraph* graph,
} else if (activation_type == "tanh") {
conv_op_desc->SetAttr("fuse_tanh", true);
} else if (activation_type == "swish") {
float scale = scale_op_desc->GetAttr<float>("beta");
float scale = 1.0;
if (scale_op_desc->HasAttr("beta")) {
scale = scale_op_desc->GetAttr<float>("beta");
}
conv_op_desc->SetAttr("swish_scale", scale);
conv_op_desc->SetAttr("fuse_swish", true);
} else if (activation_type == "abs") {
Expand Down
5 changes: 4 additions & 1 deletion lite/core/optimizer/mir/fusion/scale_activation_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) {
auto prelu_mode = act_op_desc->GetAttr<std::string>("mode");
op_desc.SetAttr("mode", prelu_mode);
} else if (act_type_ == "swish") {
float scale = act_op_desc->GetAttr<float>("beta");
float scale = 1.0;
if (act_op_desc->HasAttr("beta")) {
scale = act_op_desc->GetAttr<float>("beta");
}
op_desc.SetAttr("beta", scale);
}
auto& out_name = matched.at("output")->arg()->name;
Expand Down

0 comments on commit 0741366

Please sign in to comment.