diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 272eb0c252566..ad6dd3c201019 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,7 +9,7 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const { +bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -19,8 +19,8 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node if (!optimizer_utils::IsScalar(input_arg)) return false; const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); if (!tensor_proto) return false; - Initializer init_const{*tensor_proto, graph.ModelPath()}; if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false; + Initializer init_const{*tensor_proto, graph.ModelPath()}; index = *(init_const.data()); axis = 0; // Default value. auto& attrs = node.GetAttributes(); @@ -28,6 +28,7 @@ bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node auto& axis_attr = attrs.at("axis"); if (utils::HasInt(axis_attr)) axis = axis_attr.i(); } + indices_n_dims = tensor_proto->dims_size(); return true; } @@ -79,11 +80,19 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le bool can_fuse = true; bool first_edge = true; int64_t split_axis = 0; + int64_t indices_n_dims = -1; InlinedVector gather_outputs(output_count, nullptr); InlinedVector> nodes_to_fuse; for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { - int64_t index, axis; - if (!IsSupportedGather(graph, *it, index, axis)) { + int64_t index, axis, dims; + if (!IsSupportedGather(graph, *it, index, axis, dims)) { + can_fuse = false; + break; + } + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. can_fuse = false; break; } @@ -125,43 +134,54 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le } InlinedVector split_outputs; - for (size_t i = 0; i < output_count; ++i) { - split_outputs.emplace_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); + bool add_squeeze_node = indices_n_dims == 0; + if (add_squeeze_node) { + for (size_t i = 0; i < output_count; ++i) { + split_outputs.emplace_back( + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); + } } Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {node.MutableOutputDefs()[0]}, split_outputs); + {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs); split_node.AddAttribute("axis", split_axis); split_node.SetExecutionProviderType(node.GetExecutionProviderType()); - // Squeeze before and after OpSet-13 have different schemas. + // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. int onnx_opset_version = -1; if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); } if (onnx_opset_version < 13) { - for (size_t i = 0; i < output_count; ++i) { - Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); - squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + if (add_squeeze_node) { + for (size_t i = 0; i < output_count; ++i) { + Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", + "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); + squeeze_node.AddAttribute("axes", std::vector{split_axis}); + squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + } } } else { - ONNX_NAMESPACE::TensorProto axes_initializer_proto; - axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer")); - axes_initializer_proto.add_dims(static_cast(1)); - axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector axes_value{split_axis}; - axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); - NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(output_count)); + } - for (size_t i = 0; i < output_count; ++i) { - Node& squeeze_node = - graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + if (add_squeeze_node) { + ONNX_NAMESPACE::TensorProto axes_initializer_proto; + axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer")); + axes_initializer_proto.add_dims(static_cast(1)); + axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + InlinedVector axes_value{split_axis}; + axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); + NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); + + for (size_t i = 0; i < output_count; ++i) { + Node& squeeze_node = + graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", + "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); + squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + } } } diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h index bdb8c1be82d5a..44c235915b6cc 100644 --- a/onnxruntime/core/optimizer/gather_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -20,7 +20,7 @@ class GatherToSplitFusion : public GraphTransformer { Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const; + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const; }; /** diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b37953b508e5e..3b4b793ffc00a 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6270,7 +6270,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] ==3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; // OpSet-12 { @@ -6325,6 +6325,128 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } + + // OpSet-18 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* shape_arg = builder.MakeInput({{1}}); + auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({1}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) + .AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + + // OpSet-12 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + + // OpSet-14 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + + // OpSet-18 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } } TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {