Skip to content

Commit

Permalink
fix: Pad/AveragePool fusion (microsoft#23190)
Browse files Browse the repository at this point in the history
### Description
Fusing Pad & AveragePool requires AveragePool to use
`count_include_pad=1`. If the AveragePool already set some padding and
`count_include_pad=0`, fusion can't happen.

This PR adds a condition to perform fusion depending on those
attributes. If fusion occurs, `count_include_pad` is always set to `1`.

### Motivation and Context
Fix microsoft#22177 (mislabelled as a performance issue but there's an actual bug
in the implementation)
Bug introduced in microsoft#21556
  • Loading branch information
mayeut authored and tarekziade committed Jan 10, 2025
1 parent 58c1a7d commit b4e2789
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 2 deletions.
29 changes: 27 additions & 2 deletions onnxruntime/core/optimizer/pad_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace onnxruntime {

bool VerifyNotCastChild(const Node& child_node) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {7, 10, 11, 19}) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
return false;
}
Expand All @@ -31,11 +31,32 @@ bool VerifyNotCastChild(const Node& child_node) {
return false;
}

if (child_node.OpType() == "AveragePool") {
// in case there's already padding and count_include_pad is 0, fusion can't be performed
auto has_pad = false;
if (child_node.GetAttributes().find("pads") != child_node.GetAttributes().end()) {
auto const& pads_values = child_node.GetAttributes().at("pads").ints();
if (!pads_values.empty()) {
has_pad = std::any_of(pads_values.begin(), pads_values.end(), [](int64_t value) { return value != 0; });
}
}
if (has_pad && child_node.GetAttributes().find("count_include_pad") != child_node.GetAttributes().end()) {
if (child_node.GetAttributes().at("count_include_pad").i() == 0) {
return false;
}
}
}

return true;
}

void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
auto reset_pads = true;
if (child_node.GetAttributes().find("pads") != child_node.GetAttributes().end()) {
/* pads can be empty, overwrite pads attribute in this case */
reset_pads = child_node.GetAttributes().at("pads").ints().empty();
}
if (reset_pads) {
std::vector<int64_t> pads(pads_size - 4, 0);
child_node.AddAttribute("pads", pads);
}
Expand All @@ -49,6 +70,10 @@ void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_v
uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
}

if (child_node.OpType() == "AveragePool") {
child_node.AddAttribute("count_include_pad", static_cast<int64_t>(1));
}
}
/*
* Before:
Expand Down
122 changes: 122 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,128 @@ TEST_F(GraphTransformationTests, FusePadWithMaxPoolOpsetLessThan11) {
}
}

TEST_F(GraphTransformationTests, FusePadWithAvgPool) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
auto const& pads_proto = node.GetAttributes().at("pads").ints();
gsl::span<const int64_t> pads_values = gsl::make_span(pads_proto.data(), pads_proto.size());
expected_pads.resize(pads_values.size() - 4);
for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["AveragePool"], 1);

for (auto& node : graph.Nodes()) {
if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
auto const& count_include_pad = node.GetAttributes().at("count_include_pad");
ASSERT_NE(count_include_pad.i(), 0) << "fusion should ensure count_include_pad!=0";
ASSERT_EQ(child_pads.size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the AvgPool node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads.Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

TEST_F(GraphTransformationTests, FusePadWithAvgPoolWithPad) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool_with_pad.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::vector<int64_t> expected_pads;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "Pad") {
auto const& pads_proto = node.GetAttributes().at("pads").ints();
gsl::span<const int64_t> pads_values = gsl::make_span(pads_proto.data(), pads_proto.size());
expected_pads.resize(pads_values.size() - 4);

for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
expected_pads[index] = pads_values[pads_index];
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
}
} else if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
for (uint32_t index = 0; index < expected_pads.size(); index++) {
expected_pads[index] += child_pads.Get(index);
}
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 0);
ASSERT_EQ(op_to_count["AveragePool"], 1);

for (auto& node : graph.Nodes()) {
if (node.OpType() == "AveragePool") {
auto const& child_pads = node.GetAttributes().at("pads").ints();
auto const& count_include_pad = node.GetAttributes().at("count_include_pad");
ASSERT_NE(count_include_pad.i(), 0) << "fusion should ensure count_include_pad!=0";
ASSERT_EQ(child_pads.size(), static_cast<int32_t>(expected_pads.size()))
<< "fusion should produce the same size of pads integer as the AvgPool node";
for (uint32_t index = 0; index < expected_pads.size(); index++) {
ASSERT_EQ(expected_pads[index], child_pads.Get(index))
<< "fusion does not produce correct padding value";
}
}
}
}

// should not fuse
TEST_F(GraphTransformationTests, FusePadWithAvgPoolWithPadNoInclude) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-avgpool_with_pad-nofuse.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["Pad"], 1);
ASSERT_EQ(op_to_count["AveragePool"], 1);
}

TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";

Expand Down
68 changes: 68 additions & 0 deletions onnxruntime/test/testdata/transform/fusion/fuse-pad-avgpool-gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pathlib import Path

import numpy as np
import onnx

HERE = Path(__file__).parent.resolve(strict=True)
TEST = False

if TEST:
import onnxruntime


def generate_fuse_pad_avgpool():
parameters = {
"fuse-pad-avgpool": (
{},
[[1.333333, 2.333333, 1.777778], [3.0, 5.0, 3.666667], [2.666667, 4.333333, 3.111111]],
),
"fuse-pad-avgpool_with_pad": (
{"pads": [1, 1, 0, 0], "count_include_pad": 1},
[
[0.111111, 0.333333, 0.666667, 0.555556],
[0.555556, 1.333333, 2.333333, 1.777778],
[1.333333, 3.0, 5.0, 3.666667],
[1.222222, 2.666667, 4.333333, 3.111111],
],
),
"fuse-pad-avgpool_with_pad-nofuse": (
{"pads": [1, 1, 0, 0]},
[
[0.25, 0.5, 1.0, 0.833333],
[0.833333, 1.333333, 2.333333, 1.777778],
[2.0, 3.0, 5.0, 3.666667],
[1.833333, 2.666667, 4.333333, 3.111111],
],
),
}
for name in parameters:
model_path = HERE / f"{name}.onnx"
input_ = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 1, 3, 3))
pad = onnx.helper.make_node("Pad", ["input"], ["tp"], mode="constant", pads=[0, 0, 1, 1, 0, 0, 1, 1])
pool = onnx.helper.make_node("AveragePool", ["tp"], ["output"], kernel_shape=[3, 3], **parameters[name][0])
nodes = [pad, pool]
output_shape = (1, 1, 3, 3) if name == "fuse-pad-avgpool" else (1, 1, 4, 4)
output_ = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, output_shape)
graph = onnx.helper.make_graph(nodes, name, [input_], [output_])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 7)])
onnx.checker.check_model(model)
onnx.save_model(model, model_path)
if TEST:
input_array = np.array([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=np.float32)
expected = np.array(parameters[name][1], dtype=np.float32)
session_options = onnxruntime.SessionOptions()
session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(model_path, session_options)
out = session.run(["output"], {"input": input_array})
actual = out[0].squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0.0)
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
session = onnxruntime.InferenceSession(model_path, session_options)
out = session.run(["output"], {"input": input_array})
actual = out[0].squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0.0)


if __name__ == "__main__":
generate_fuse_pad_avgpool()
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit b4e2789

Please sign in to comment.