Skip to content

Commit

Permalink
Address subgraph node reserrection.
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Oct 27, 2023
1 parent f074431 commit 9110692
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 19 deletions.
8 changes: 8 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,14 @@ class Graph {
*/
bool RemoveNode(NodeIndex node_index);

/** Remove a Node from this Graph and free it. The function in addition
removes a corresponding NodeProto in the graph_proto if there is a match.
The function calls RemoveNode() internally, and all the requirements for RemoveNode()
must be met. In addition, the node must not produce any Graph outputs.
@returns true if the node_index was valid
*/
bool RemoveNodeAndProto(NodeIndex node_index);

/** Add an edge between two Nodes.
@param src_node_index NodeIndex of source Node that is providing output to the destination Node.
@param dst_node_index NodeIndex of destination Node that is receiving input from the source Node.
Expand Down
59 changes: 45 additions & 14 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1308,8 +1308,17 @@ Graph::Graph(const Model& owning_model,
}
}

for (const auto& node_proto : graph_proto_->node()) {
AddNode(node_proto, name_to_type_map);
{
static const std::string node_base_name{"proto_based_node"};
auto& nodes = *graph_proto_->mutable_node();
for (auto it = nodes.begin(), end = nodes.end(); it != end; ++it) {
// generate a name so the node_proto can be found by name in case it
// needs to be removed.
if (it->name().empty()) {
it->set_name(GenerateNodeName(node_base_name));
}
AddNode(*it, name_to_type_map);
}
}

if (is_loaded_from_model_file_) {
Expand Down Expand Up @@ -3303,6 +3312,36 @@ bool Graph::RemoveNode(NodeIndex p_index) {

return ReleaseNode(p_index);
}

bool Graph::RemoveNodeAndProto(NodeIndex node_index) {
auto* node = GetNode(node_index);
if (nullptr == node) {
return false;
}

const std::string node_name = node->Name();
bool result = RemoveNode(node_index);

if (result && !node_name.empty()) {
// Remove node's proto from graph_proto to prevent constant folded Node re-creation
// in subgraphs.
// E.g. when a node is constant folded and removed from a subgraph, but then re-created
// because a parent node that owns the subgraph being copied/inlined. The constant folded Node
// is then re-created based on the subgraph proto. This creates a duplicate name and invalidates the subgraph.
// This fixes the asymmetry of adding a new initializer to the graph proto, but not removing
// the node's proto with a node_arg with a name of the initializer.
auto& node_list = *graph_proto_->mutable_node();
for (auto it = node_list.begin(), end = node_list.end(); it != end; ++it) {
if (it->name() == node_name) {
ORT_IGNORE_RETURN_VALUE(node_list.erase(it));
break;
}
}
}

return result;
}

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD)
Expand Down Expand Up @@ -4046,7 +4085,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod
return Status::OK();
}

static void RenameSubgraphDependentNames(const InlinedHashMap<std::string, std::string>& name_mapping,
static void RenameSubgraphDependentNames(const std::unordered_map<std::string, std::string>& name_mapping,
ONNX_NAMESPACE::GraphProto& graph_proto) {
for (auto& node : *graph_proto.mutable_node()) {
for (auto& attr : *node.mutable_attribute()) {
Expand All @@ -4065,18 +4104,10 @@ static void RenameSubgraphDependentNames(const InlinedHashMap<std::string, std::
input = hit->second;
}
}

// XXX: Can output names depend on the outer scope?
// for (auto& output : *node.mutable_output()) {
// auto hit = name_mapping.find(output);
// if (hit != name_mapping.cend()) {
// output = hit->second;
// }
//}
}
}

static void RenameNodeAttributesSubgraphDependentNames(const InlinedHashMap<std::string, std::string>& name_mapping,
static void RenameNodeAttributesSubgraphDependentNames(const std::unordered_map<std::string, std::string>& name_mapping,
NodeAttributes& attributes) {
for (auto& attribute : attributes) {
auto& attr_proto = attribute.second;
Expand All @@ -4102,7 +4133,7 @@ Status Graph::InlineIfSubgraph(const Graph& graph_to_inline, Node& if_node) {

// Check if the name is an input or implicit input.
// These are not renamed.
InlinedHashSet<std::string_view> if_all_inputs;
std::unordered_set<std::string_view> if_all_inputs;
const auto if_inputs = if_node.InputDefs();
const auto if_implicit_inputs = if_node.ImplicitInputDefs();
if_all_inputs.reserve(if_inputs.size() + if_implicit_inputs.size());
Expand All @@ -4118,7 +4149,7 @@ Status Graph::InlineIfSubgraph(const Graph& graph_to_inline, Node& if_node) {

// Name mapping from the graph to inline to the graph we are inlining into
// we also use this to process any subgraphs in the graph we are inlining
InlinedHashMap<std::string, std::string> name_mapping;
std::unordered_map<std::string, std::string> name_mapping;

// We are going to map the outputs of the graph to inline to the outputs of the If node.
// They are assumed to be in the same order.
Expand Down
6 changes: 1 addition & 5 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,6 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
bool folded = false;
ORT_RETURN_IF_ERROR(ConstantFoldIfNode(graph, *node, logger, folded));
if (folded) {
// We do not remove any of the upstream nodes, because we actually
// do not know whether any of the upstream nodes provide constant implicit inputs
// We let the next round of constant folding check of that.
// Remove the output edges of the constant node and then remove the node itself.
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
modified = true;
Expand Down Expand Up @@ -364,7 +360,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,

// Remove the output edges of the constant node and then remove the node itself.
graph_utils::RemoveNodeOutputEdges(graph, *node);
graph.RemoveNode(node->Index());
graph.RemoveNodeAndProto(node->Index());
modified = true;
have_updated_nodes = true;
}
Expand Down

0 comments on commit 9110692

Please sign in to comment.