Fix Gather to Split optimizer (#14478)
### Description
The Gather to Split optimizer fails if opset == 18. This PR fixes one bug
and extends the unit tests.



### Motivation and Context
The model produced by the optimizer does not follow the ONNX specification
with opset 18.
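
Note added for context (not part of the original PR description): since Split-18, a Split node that has no explicit `split` input must declare its output count through the `num_outputs` attribute, which the fused Split emitted by this optimizer did not do. A condensed sketch of the change, reusing the onnxruntime graph API exactly as it appears in the gather_fusion.cc diff below (variables such as `output_count`, `split_axis`, and `onnx_opset_version` come from the surrounding optimizer code):

```cpp
// Condensed from the gather_fusion.cc change below; not a standalone program.
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
                                 {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
split_node.AddAttribute("axis", split_axis);
split_node.SetExecutionProviderType(node.GetExecutionProviderType());
if (onnx_opset_version >= 18) {
  // Split-18 requires exactly one of the 'split' input or the 'num_outputs' attribute;
  // without it the produced model fails ONNX validation.
  split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
}
```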
xadupre authored Feb 2, 2023
1 parent 3d8fa4d commit 0bcca7a
Showing 3 changed files with 170 additions and 28 deletions.
72 changes: 46 additions & 26 deletions onnxruntime/core/optimizer/gather_fusion.cc
@@ -9,7 +9,7 @@

 namespace onnxruntime {

-bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const {
+bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const {
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
       !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
     return false;
@@ -19,15 +19,16 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node
   if (!optimizer_utils::IsScalar(input_arg)) return false;
   const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
   if (!tensor_proto) return false;
-  Initializer init_const{*tensor_proto, graph.ModelPath()};
   if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false;
+  Initializer init_const{*tensor_proto, graph.ModelPath()};
   index = *(init_const.data<int64_t>());
   axis = 0; // Default value.
   auto& attrs = node.GetAttributes();
   if (attrs.find("axis") != attrs.end()) {
     auto& axis_attr = attrs.at("axis");
     if (utils::HasInt(axis_attr)) axis = axis_attr.i();
   }
+  indices_n_dims = tensor_proto->dims_size();
   return true;
 }

@@ -79,11 +80,19 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
     bool can_fuse = true;
     bool first_edge = true;
     int64_t split_axis = 0;
+    int64_t indices_n_dims = -1;
     InlinedVector<NodeArg*> gather_outputs(output_count, nullptr);
     InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
     for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
-      int64_t index, axis;
-      if (!IsSupportedGather(graph, *it, index, axis)) {
+      int64_t index, axis, dims;
+      if (!IsSupportedGather(graph, *it, index, axis, dims)) {
         can_fuse = false;
         break;
       }
+      if (indices_n_dims == -1) {
+        indices_n_dims = dims;
+      } else if (indices_n_dims != dims) {
+        // Not the same number of dimensions (0 or 1) for all scalar indices.
+        can_fuse = false;
+        break;
+      }
@@ -125,43 +134,54 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
     }

     InlinedVector<NodeArg*> split_outputs;
-    for (size_t i = 0; i < output_count; ++i) {
-      split_outputs.emplace_back(
-          &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
+    bool add_squeeze_node = indices_n_dims == 0;
+    if (add_squeeze_node) {
+      for (size_t i = 0; i < output_count; ++i) {
+        split_outputs.emplace_back(
+            &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
+      }
     }

     Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
-                                     {node.MutableOutputDefs()[0]}, split_outputs);
+                                     {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
     split_node.AddAttribute("axis", split_axis);
     split_node.SetExecutionProviderType(node.GetExecutionProviderType());

-    // Squeeze before and after OpSet-13 have different schemas.
+    // Squeeze-11, Squeeze-13, Split-13, Split-18 have different schemas.
     int onnx_opset_version = -1;
     if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
       onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
     }

     if (onnx_opset_version < 13) {
-      for (size_t i = 0; i < output_count; ++i) {
-        Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
-                                           "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
-        squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
-        squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
+      if (add_squeeze_node) {
+        for (size_t i = 0; i < output_count; ++i) {
+          Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
+                                             "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
+          squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
+          squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
+        }
       }
     } else {
-      ONNX_NAMESPACE::TensorProto axes_initializer_proto;
-      axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
-      axes_initializer_proto.add_dims(static_cast<int64_t>(1));
-      axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-      InlinedVector<int64_t> axes_value{split_axis};
-      axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
-      NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
+      if (onnx_opset_version >= 18) {
+        split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
+      }

-      for (size_t i = 0; i < output_count; ++i) {
-        Node& squeeze_node =
-            graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
-                          "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
-        squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
+      if (add_squeeze_node) {
+        ONNX_NAMESPACE::TensorProto axes_initializer_proto;
+        axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer"));
+        axes_initializer_proto.add_dims(static_cast<int64_t>(1));
+        axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+        InlinedVector<int64_t> axes_value{split_axis};
+        axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
+        NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);
+
+        for (size_t i = 0; i < output_count; ++i) {
+          Node& squeeze_node =
+              graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
+                            "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
+          squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
+        }
       }
     }

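Editorial note, not part of the commit: the other fix in gather_fusion.cc is that Squeeze nodes are now only appended when the Gather indices are 0-D scalars. ONNX Gather replaces the gathered axis with the shape of its indices, so a scalar index drops the axis while a 1-D index of length one keeps it, whereas Split always keeps the axis. A small self-contained sketch of that shape rule (the helper name is illustrative and does not exist in the repository):

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper: ONNX Gather's output shape is the data shape with the
// gathered axis replaced by the shape of the indices tensor.
std::vector<int64_t> GatherOutputShape(const std::vector<int64_t>& data_shape,
                                       const std::vector<int64_t>& indices_shape,
                                       size_t axis) {
  std::vector<int64_t> out(data_shape.begin(), data_shape.begin() + axis);
  out.insert(out.end(), indices_shape.begin(), indices_shape.end());
  out.insert(out.end(), data_shape.begin() + axis + 1, data_shape.end());
  return out;
}

// For data of shape {2, 3, 3, 3} gathered on axis 2:
//   scalar indices (shape {})  -> {2, 3, 3}    : the Split output {2, 3, 1, 3} still needs a Squeeze
//   1-D indices    (shape {1}) -> {2, 3, 1, 3} : already matches the Split output, so no Squeeze
```

This is why the change above only builds Squeeze nodes when `indices_n_dims == 0` and otherwise wires the Split outputs directly to the consumers of the Gather nodes.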
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/gather_fusion.h
@@ -20,7 +20,7 @@ class GatherToSplitFusion : public GraphTransformer {
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

  private:
-  bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const;
+  bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
 };

 /**
124 changes: 123 additions & 1 deletion onnxruntime/test/optimizer/graph_transform_test.cc
@@ -6270,7 +6270,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
     builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
   };

-  auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] ==3); return Status::OK(); };
+  auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };

   // OpSet-12
   {
@@ -6325,6 +6325,128 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) {
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}

// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
} else if (node.OpType() == "Squeeze") {
const NodeArg& input_arg = *(node.InputDefs()[1]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
TEST_RETURN_IF_NOT(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
TEST_RETURN_IF_NOT(2 == static_cast<int>(*(init_const.data<int64_t>())));
}
}
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
}

TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{54}});
auto* shape_arg = builder.MakeInput<int64_t>({{1}});
auto* reshape_out = builder.MakeIntermediate<float>({{2, 3, 3, 3}});
auto* gather_index_1 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({1}, {static_cast<int64_t>(2)});
auto* gather_out_1 = builder.MakeIntermediate();
auto* gather_out_2 = builder.MakeIntermediate();
auto* gather_out_3 = builder.MakeIntermediate();
auto* transpose_out_1 = builder.MakeOutput();
auto* transpose_out_2 = builder.MakeOutput();
auto* transpose_out_3 = builder.MakeOutput();

builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out});
builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2})
.AddAttribute("axis", static_cast<int64_t>(-2));
builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3})
.AddAttribute("axis", static_cast<int64_t>(2));
builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector<int64_t>{0, 2, 1});
};

auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); };

// OpSet-12
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
}
}
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}

// OpSet-14
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
}
}
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}

// OpSet-18
{
auto post_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Split") {
auto& attrs = node.GetAttributes();
TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end());
TEST_RETURN_IF_NOT(2 == static_cast<int>(attrs.at("axis").i()));
}
}
return Status::OK();
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
}

TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
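Editorial note, not part of the commit: the two tests differ only in how the Gather indices are built. GatherToSplitFusion (whose builder sits above the shown hunk) uses 0-D scalar indices, so its checkers expect three Squeeze nodes after fusion; GatherToSplitFusion_NoSqueeze uses shape-{1} indices and expects none. A hypothetical builder fragment for the scalar-index variant, assuming ModelTestBuilder::MakeInitializer accepts an empty shape for a 0-D initializer:

```cpp
// Hypothetical fragment, not copied from graph_transform_test.cc: scalar Gather indices.
auto* gather_index_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* gather_index_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* gather_index_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(2)});
// With these indices the post-graph checkers expect CountOpsInGraph(graph)["Squeeze"] == 3.
```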
