Skip to content

Commit

Permalink
Extend Pad Fusion for AveragePool (#21556)
Browse files Browse the repository at this point in the history
### Description
This extends the existing pad_fusion for AveragePool operator i.e. fuse
Pad if it is followed by AveragePool operator.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
sumitsays authored and prathikr committed Aug 7, 2024
1 parent 2bc6192 commit 3383889
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion onnxruntime/core/optimizer/pad_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace onnxruntime {
* It matches following pattern:
* Pad
* |
* Conv/MaxPool
* Conv/MaxPool/AveragePool
*/
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
// if Pad has input axis, don't fuse it.
Expand All @@ -28,6 +28,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log

const Node& child_node = *node.OutputNodesBegin();
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/pad_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace onnxruntime {
/*
* This fusion submerges a Pad operator to it's child
* Conv or MaxPool operator, if and only if PadFusion::SatisfyCondition()
* Conv or MaxPool or AveragePool operator, if and only if PadFusion::SatisfyCondition()
* is true.
*/
class PadFusion : public RewriteRule {
Expand Down

0 comments on commit 3383889

Please sign in to comment.