Skip to content

Commit

Permalink
Fix missing node during mem efficient topo sort (#20497)
Browse files Browse the repository at this point in the history
### Fix missing node during mem efficient topo sort

Some nodes are not consumed by the backward path, and they also do not
generate graph outputs. We missed those nodes, so this PR fixes that and
adds related tests.

A side note: we should remove those nodes that are not used for
computing any graph outputs in a graph transformer. (TODO)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored May 6, 2024
1 parent a366920 commit addcc4c
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 9 deletions.
60 changes: 51 additions & 9 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,14 @@ struct GroupNode {
intermediate_args.insert(arg);
}

if (node->GetOutputEdgesCount() == 0) {
for (const NodeArg* arg : node->OutputDefs()) {
output_args.push_back(arg);
}

continue;
}

for (auto output_edge_it = node->OutputEdgesBegin(); output_edge_it != node->OutputEdgesEnd();
++output_edge_it) {
const Node* output_node = &output_edge_it->GetNode();
Expand Down Expand Up @@ -2012,7 +2020,8 @@ void FindBranchGraph(
const InlinedVector<const Node*>& branch_graph_input_nodes,
const InlinedVector<size_t>& backward_node_in_degree,
InlinedVector<const Node*>& branch_graph,
InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers) {
InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers,
InlinedVector<const NodeArg*>& branch_subgraph_outputs) {
// Loop through the branch_graph_input_nodes to find the branch subgraphs by its output edges in BFS,
// and find the maximum self_contained subgraph taking the branch_graph_input_nodes as input nodes.
std::queue<const Node*> to_visit_queue;
Expand Down Expand Up @@ -2044,30 +2053,41 @@ void FindBranchGraph(
// At this point, branch_graph is a big subgraph that contains all the nodes that are purely
// triggered by the branch_graph_input_nodes, other graph input/initializers and leaf nodes (for example Constant).
for (const Node* n : branch_graph) {
if (n->GetOutputEdgesCount() == 0) {
// In case the node connects to graph outputs or to nothing, append all of its outputs as the branch subgraph outputs.
for (auto output_def : n->OutputDefs()) {
branch_subgraph_outputs.push_back(output_def);
}
continue;
}

for (auto output_it = n->OutputEdgesBegin(); output_it != n->OutputEdgesEnd(); ++output_it) {
const Node* output_node = &output_it->GetNode();
const size_t dest_in_port = output_it->GetDstArgIndex();
if (std::find(branch_graph.begin(), branch_graph.end(), output_node) == branch_graph.end()) {
branch_subgraph_consumers.push_back({output_node, dest_in_port});
branch_subgraph_outputs.push_back(n->OutputDefs()[output_it->GetSrcArgIndex()]);
}
}
}
}

void TagNodeToAssociatedOutputs(const Graph* graph,
const InlinedHashSet<const Node*>& nodes_to_execute_before_yieldop,
const InlinedVector<std::pair<const Node*, size_t>>& branch_subgraph_consumers,
const InlinedVector<const NodeArg*>& branch_subgraph_outputs,
const InlinedVector<const Node*>& branch_graph,
InlinedVector<GroupNode>& group_node_collection,
InlinedHashMap<const NodeArg*, GroupNode*>& output_arg_to_grouped_node) {
// Reverse DFS from branch graph outputs (e.g. branch_subgraph_consumers) to tag each nodes:
// Reverse DFS from branch graph outputs (e.g. branch_subgraph_outputs) to tag each nodes:
// If one node N contributes to a graph output A, then we will tag A to N.
// If the node N contributes to multiple graph outputs A, B, C, then we will tag the A, B, C to N.
InlinedHashMap<const Node*, std::set<const NodeArg*>> node_to_its_associated_outputs;
node_to_its_associated_outputs.reserve(branch_graph.size());
for (const auto& consumer : branch_subgraph_consumers) {
const NodeArg* output_arg = consumer.first->InputDefs()[consumer.second];
InlinedHashSet<const Node*> handled_branch_subgraph_end_nodes;
for (const auto& output_arg : branch_subgraph_outputs) {
const Node* end_node = graph->GetProducerNode(output_arg->Name());
handled_branch_subgraph_end_nodes.insert(end_node);

InlinedVector<const Node*> end_nodes{end_node};
graph->ReverseDFSFrom(
end_nodes,
Expand Down Expand Up @@ -2097,6 +2117,7 @@ void TagNodeToAssociatedOutputs(const Graph* graph,
group_node_collection.reserve(associated_outputs_to_nodes.size());
for (auto& [associated_outputs, nodes] : associated_outputs_to_nodes) {
group_node_collection.push_back(nodes);

// Flatten the key into NodeArg* for better search.
GroupNode& grouped_node = group_node_collection.back();
for (const auto& output_arg : grouped_node.output_args) {
Expand Down Expand Up @@ -2184,6 +2205,7 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,

InlinedVector<const Node*> branch_graph_input_nodes;
branch_graph_input_nodes.reserve(num_of_backward_nodes);

PrepareToFindBranchGraph(this,
nodes_to_execute_before_yieldop,
branch_graph_input_nodes,
Expand All @@ -2193,17 +2215,19 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,
InlinedVector<const Node*> branch_graph;
branch_graph.reserve(num_of_backward_nodes);
InlinedVector<std::pair<const Node*, size_t>> branch_subgraph_consumers;
InlinedVector<const NodeArg*> branch_subgraph_outputs;
FindBranchGraph(branch_graph_input_nodes,
backward_node_in_degree,
branch_graph,
branch_subgraph_consumers);
branch_subgraph_consumers,
branch_subgraph_outputs);

// Cluster the nodes in the branch_graph based on the associated outputs.
InlinedVector<GroupNode> group_node_collection;
InlinedHashMap<const NodeArg*, GroupNode*> output_arg_to_grouped_node;
TagNodeToAssociatedOutputs(this,
nodes_to_execute_before_yieldop,
branch_subgraph_consumers,
branch_subgraph_outputs,
branch_graph,
group_node_collection,
output_arg_to_grouped_node);
Expand Down Expand Up @@ -2247,9 +2271,27 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op,
// For the group nodes that are not outputted, we need to output them.
// Hitting this code path means some nodes are consuming outputs of forward nodes, and their outputs
// are not used by main branch backward nodes.
for (const auto& [output_arg, grouped_node] : output_arg_to_grouped_node) {
InlinedVector<std::pair<const NodeArg*, GroupNode*>>
left_output_arg_to_grouped_node_vector; // To ensure deterministic order.
left_output_arg_to_grouped_node_vector.reserve(output_arg_to_grouped_node.size());
for (auto& [output_arg, grouped_node] : output_arg_to_grouped_node) {
if (!grouped_node->is_outputted) {
OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order);
left_output_arg_to_grouped_node_vector.push_back({output_arg, grouped_node});
}
}

if (!left_output_arg_to_grouped_node_vector.empty()) {
// Sort to ensure deterministic order.
std::sort(left_output_arg_to_grouped_node_vector.begin(), left_output_arg_to_grouped_node_vector.end(),
[](const std::pair<const NodeArg*, GroupNode*>& a, const std::pair<const NodeArg*, GroupNode*>& b) {
return a.first->Name() < b.first->Name();
});
for (const auto& pair : left_output_arg_to_grouped_node_vector) {
const NodeArg* output_arg = pair.first;
GroupNode* grouped_node = pair.second;
if (!grouped_node->is_outputted) {
OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order);
}
}
}

Expand Down
79 changes: 79 additions & 0 deletions onnxruntime/test/ir/graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2291,6 +2291,85 @@ TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_MultiLayerRec
}
}

// Verifies that the MEMORY_EFFICIENT topological sort still emits nodes whose outputs have no
// consumer edges (here: the two "graph output identity" nodes). One of them hangs directly off a
// forward node (node_0), the other hangs off a recompute node that also feeds the backward pass.
TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_SubgraphGeneratingNodeHavingNoConsumers) {
  Model model("graph_1", false, *logger_);
  auto& graph = model.MainGraph();

  /*
     Topology built below (edges follow the AddNode calls):

                          node_0
                 /          |            \
             node_1   recompute_node_1   graph_output_0_identity   (no consumers)
               |         /         \
             node_4     /     graph_output_1_identity               (no consumers)
               |       /
           node_yield /
               \     /
             node_1_grad
  */

  TypeProto tensor_int32;
  tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
  tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);

  auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32);
  auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32);
  auto& graph_output_0_identity = graph.GetOrCreateNodeArg("graphoutput_0_identity_out_1", &tensor_int32);
  auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32);
  auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32);
  auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32);
  auto& output_arg5 = graph.GetOrCreateNodeArg("node_yield_out_1", &tensor_int32);
  auto& output_arg6 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32);

  // Forward node and its three consumers.
  graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0});
  graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1});
  graph.AddNode("graph_output_0_identity", "Identity_Fake", "graph output 0 identity", {&output_arg0}, {&graph_output_0_identity});
  graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2});

  // Second consumer-less node: fed by the recompute node, its output is used by nobody.
  auto& graph_output1_identity = graph.GetOrCreateNodeArg("graphoutput_1_identity_out_1", &tensor_int32);
  graph.AddNode("graph_output_1_identity", "Identity_Fake", "graph output 1 identity", {&output_arg2}, {&graph_output1_identity});

  graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4});

  ONNX_NAMESPACE::AttributeProto full_shape_outputs;
  const std::string attribute_name = "full_shape_outputs";
  full_shape_outputs.set_name(attribute_name);
  full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS);
  full_shape_outputs.add_ints(static_cast<int64_t>(0));
  NodeAttributes attributes({{attribute_name, full_shape_outputs}});

  graph.AddNode("node_yield", "YieldOp", "node yield", {&output_arg4}, {&output_arg5}, &attributes, kMSDomain);
  graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg5, &output_arg2}, {&output_arg6});

  auto status = graph.Resolve();
  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
  GraphViewer graph_viewer(graph);

  // MEMORY_EFFICIENT order
  {
    auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::MEMORY_EFFICIENT);
    const std::vector<std::string> expected_order =
        {
            "node_0",
            "node_1",
            "node_4",
            "node_yield",
            "recompute_node_1",
            "node_1_grad",
            "graph_output_0_identity",
            "graph_output_1_identity",
        };

    // Guard the indexed comparison below: a longer result would read past the end of
    // expected_order, and a shorter one (a dropped node, the very bug under test) would
    // otherwise pass unnoticed.
    ASSERT_EQ(order.size(), expected_order.size())
        << "MEMORY_EFFICIENT based execution order has an unexpected number of nodes.";

    for (size_t i = 0; i < order.size(); ++i) {
      auto node = graph.GetNode(order[i]);
      ASSERT_TRUE(node != nullptr) << "No node found for topological order entry " << i;
      EXPECT_TRUE(node->Name() == expected_order[i])
          << "MEMORY_EFFICIENT based execution order is wrong. expected node is " << expected_order[i]
          << " but got " << node->Name();
    }
  }
}

#endif

} // namespace test
Expand Down

0 comments on commit addcc4c

Please sign in to comment.