Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing node during mem efficient topo sort #20497

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,14 @@ struct GroupNode {
intermediate_args.insert(arg);
}

if (node->GetOutputEdgesCount() == 0) {
for (const NodeArg* arg : node->OutputDefs()) {
output_args.push_back(arg);
}

continue;
}

for (auto output_edge_it = node->OutputEdgesBegin(); output_edge_it != node->OutputEdgesEnd();
++output_edge_it) {
const Node* output_node = &output_edge_it->GetNode();
Expand Down Expand Up @@ -2012,7 +2020,8 @@ void FindBranchGraph(
const InlinedVector<const Node*>& branch_graph_input_nodes,
const InlinedVector<size_t>& backward_node_in_degree,
InlinedVector<const Node*>& branch_graph,
InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers) {
InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers,
InlinedVector<const NodeArg*>& branch_subgraph_outputs) {
// Loop through the branch_graph_input_nodes to find the branch subgraphs by its output edges in BFS,
// and find the maximum self_contained subgraph taking the branch_graph_input_nodes as input nodes.
std::queue<const Node*> to_visit_queue;
Expand Down Expand Up @@ -2044,30 +2053,41 @@ void FindBranchGraph(
// At this point, branch_graph is a big subgraph that contains all the nodes that are purely
// triggered by the branch_graph_input_nodes, other graph input/initializers and leaf nodes (for example Constant).
for (const Node* n : branch_graph) {
if (n->GetOutputEdgesCount() == 0) {
// In case the node connects to graph outputs or nothing, append all its outputs as the branch subgraph outputs.
for (auto output_def : n->OutputDefs()) {
branch_subgraph_outputs.push_back(output_def);
}
continue;
}

for (auto output_it = n->OutputEdgesBegin(); output_it != n->OutputEdgesEnd(); ++output_it) {
const Node* output_node = &output_it->GetNode();
const size_t dest_in_port = output_it->GetDstArgIndex();
if (std::find(branch_graph.begin(), branch_graph.end(), output_node) == branch_graph.end()) {
branch_subgraph_consumers.push_back({output_node, dest_in_port});
branch_subgraph_outputs.push_back(n->OutputDefs()[output_it->GetSrcArgIndex()]);
}
}
}
}

void TagNodeToAssociatedOutputs(const Graph* graph,
const InlinedHashSet<const Node*>& nodes_to_execute_before_yieldop,
const InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers,
const InlinedVector<const NodeArg*>& branch_subgraph_outputs,
const InlinedVector<const Node*>& branch_graph,
InlinedVector<GroupNode>& group_node_collection,
InlinedHashMap<const NodeArg*, GroupNode*>& output_arg_to_grouped_node) {
// Reverse DFS from branch graph outputs (e.g. branch_subgraph_consumers) to tag each nodes:
// Reverse DFS from branch graph outputs (e.g. branch_subgraph_outputs) to tag each node:
// If one node N contributes to a graph output A, then we will tag A to N.
// If the node N contributes to multiple graph outputs A, B, C, then we will tag the A, B, C to N.
InlinedHashMap<const Node*, std::set<const NodeArg*>> node_to_its_associated_outputs;
node_to_its_associated_outputs.reserve(branch_graph.size());
for (const auto& consumer : branch_subgraph_consumers) {
const NodeArg* output_arg = consumer.first->InputDefs()[consumer.second];
InlinedHashSet<const Node*> handled_branch_subgraph_end_nodes;
for (const auto& output_arg : branch_subgraph_outputs) {
const Node* end_node = graph->GetProducerNode(output_arg->Name());
handled_branch_subgraph_end_nodes.insert(end_node);

InlinedVector<const Node*> end_nodes{end_node};
graph->ReverseDFSFrom(
end_nodes,
Expand Down Expand Up @@ -2097,6 +2117,7 @@ void TagNodeToAssociatedOutputs(const Graph* graph,
group_node_collection.reserve(associated_outputs_to_nodes.size());
for (auto& [associated_outputs, nodes] : associated_outputs_to_nodes) {
group_node_collection.push_back(nodes);

// Flatten the key into NodeArg* for better search.
GroupNode& grouped_node = group_node_collection.back();
for (const auto& output_arg : grouped_node.output_args) {
Expand Down Expand Up @@ -2184,6 +2205,7 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,

InlinedVector<const Node*> branch_graph_input_nodes;
branch_graph_input_nodes.reserve(num_of_backward_nodes);

PrepareToFindBranchGraph(this,
nodes_to_execute_before_yieldop,
branch_graph_input_nodes,
Expand All @@ -2193,17 +2215,19 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,
InlinedVector<const Node*> branch_graph;
branch_graph.reserve(num_of_backward_nodes);
InlinedVector<std::pair<const Node*, size_t>> branch_subgraph_consumers;
InlinedVector<const NodeArg*> branch_subgraph_outputs;
FindBranchGraph(branch_graph_input_nodes,
backward_node_in_degree,
branch_graph,
branch_subgraph_consumers);
branch_subgraph_consumers,
branch_subgraph_outputs);

// Cluster the nodes in the branch_graph based on the associated outputs.
InlinedVector<GroupNode> group_node_collection;
InlinedHashMap<const NodeArg*, GroupNode*> output_arg_to_grouped_node;
TagNodeToAssociatedOutputs(this,
nodes_to_execute_before_yieldop,
branch_subgraph_consumers,
branch_subgraph_outputs,
branch_graph,
group_node_collection,
output_arg_to_grouped_node);
Expand Down Expand Up @@ -2247,9 +2271,27 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,
// For the group nodes that are not outputted, we need to output them.
// Hitting this code path means some nodes are consuming outputs of forward nodes, and their outputs
// are not used by main branch backward nodes.
for (const auto& [output_arg, grouped_node] : output_arg_to_grouped_node) {
InlinedVector<std::pair<const NodeArg*, GroupNode*>>
left_output_arg_to_grouped_node_vector; // To ensure deterministic order.
left_output_arg_to_grouped_node_vector.reserve(output_arg_to_grouped_node.size());
for (auto& [output_arg, grouped_node] : output_arg_to_grouped_node) {
if (!grouped_node->is_outputted) {
OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order);
left_output_arg_to_grouped_node_vector.push_back({output_arg, grouped_node});
}
}

if (!left_output_arg_to_grouped_node_vector.empty()) {
// Sort to ensure deterministic order.
std::sort(left_output_arg_to_grouped_node_vector.begin(), left_output_arg_to_grouped_node_vector.end(),
[](const std::pair<const NodeArg*, GroupNode*>& a, const std::pair<const NodeArg*, GroupNode*>& b) {
return a.first->Name() < b.first->Name();
});
for (const auto& pair : left_output_arg_to_grouped_node_vector) {
const NodeArg* output_arg = pair.first;
GroupNode* grouped_node = pair.second;
if (!grouped_node->is_outputted) {
OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order);
}
}
}

Expand Down
79 changes: 79 additions & 0 deletions onnxruntime/test/ir/graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2291,6 +2291,85 @@
}
}

// Regression test for the memory-efficient topological sort: a node inside a
// recomputed branch subgraph whose outputs feed only graph outputs (i.e. it has
// no output edges / downstream consumers) must still appear in the sorted order.
TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_SubgraphGeneratingNodeHavingNoConsumers) {
  Model model("graph_1", false, *logger_);
  auto& graph = model.MainGraph();

  /*
                 |
         node_0 (Identity)
          /        \   \
  node_1 (Identity) \   Identity
         |           |    \_____graph_output_0
  node_4 (Identity)  |
         |           |
      YieldOp     recompute_node_1
          \         /  \
      node_1_grad (Merge) Identity
           |                |
                      graph_output_1
  */

  // All node args share a simple scalar-like int32 tensor type.
  TypeProto tensor_int32;
  tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
  tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32);
  auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32);
  auto& graph_output_0_identity = graph.GetOrCreateNodeArg("graphoutput_0_identity_out_1", &tensor_int32);
  auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
  auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32);
  auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32);
  auto& output_arg5 = graph.GetOrCreateNodeArg("node_yield_out_1", &tensor_int32);
  auto& output_arg6 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32);

  graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
  graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});
  // Wrapped to keep lines within the 120-character lint limit.
  graph.AddNode("graph_output_0_identity", "Identity_Fake", "graph output 0 identity",
                {&output_arg0}, {&graph_output_0_identity});
  graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2});

  // This node consumes the recomputed value but its output is only a graph
  // output — the case the sort previously dropped.
  auto& graph_output1_identity = graph.GetOrCreateNodeArg("graphoutput_1_identity_out_1", &tensor_int32);
  graph.AddNode("graph_output_1_identity", "Identity_Fake", "graph output 1 identity",
                {&output_arg2}, {&graph_output1_identity});

  graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4});

  // YieldOp requires the full_shape_outputs attribute.
  ONNX_NAMESPACE::AttributeProto full_shape_outputs;
  const std::string attribute_name = "full_shape_outputs";
  full_shape_outputs.set_name(attribute_name);
  full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
  full_shape_outputs.add_ints(static_cast<int64_t>(0));
  NodeAttributes attributes({{attribute_name, full_shape_outputs}});

  graph.AddNode("node_yield", "YieldOp", "node yield", {&output_arg4}, {&output_arg5}, &attributes, kMSDomain);
  graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg5, &output_arg2}, {&output_arg6});

  auto status = graph.Resolve();
  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
  GraphViewer graph_viewer(graph);

  // MEMORY_EFFICIENT order
  {
    auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::MEMORY_EFFICIENT);
    const std::vector<std::string> expected_order = {
        "node_0",
        "node_1",
        "node_4",
        "node_yield",
        "recompute_node_1",
        "node_1_grad",
        "graph_output_0_identity",
        "graph_output_1_identity",
    };
    for (size_t i = 0; i < order.size(); ++i) {
      auto node = graph.GetNode(order[i]);
      EXPECT_TRUE(node->Name() == expected_order[i])
          << "MEMORY_EFFICIENT based execution order is wrong. expected node is " << expected_order[i]
          << " but got " << node->Name();
    }
  }
}

#endif

} // namespace test
Expand Down
Loading