-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DML EP] DML Graph Serialization Bug (#19748)
### Description This pull request addresses several issues: - The DML Graph's nodes were not sorted in a topologically ordered sequence, leading to crashes during deserialization when a child node preceded its parent node. This PR resolves this issue by implementing a topological sorting algorithm before serialization. - During the `RemoveUnconnectedNodes` process: - we update `intermeidateEdge.FromNodeIndex`. Additionally, we must update `intermediateEdge.Name` when it includes `intermediateEdge.FromNodeIndex`, as serialization/deserialization heavily relies on edge names. - we also eliminate unused edges. Consequently, we must erase inputs (now unused) from corresponding maps `serializedGraphInputIndexToSubgraphInputIndex` and `serializedGraphLargeConstantNameToSubgraphInputIndex`. ### Motivation and Context Why is this change required? What problem does it solve? There are few ONNX Zoo public models which were crashing during deserialization. <!-- - - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Jeff Bloomfield <[email protected]>
- Loading branch information
Showing
6 changed files
with
151 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
...ime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphHelper.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
|
||
#pragma once | ||
#include <queue> | ||
|
||
inline void PerformTopologicalSortAndCheckIsAcyclic( | ||
const DmlSerializedGraphDesc& graphDesc, | ||
std::vector<uint32_t>& nodesInTopologicalOrder) | ||
{ | ||
uint32_t nodeCount = static_cast<uint32_t>(graphDesc.Nodes.size()); | ||
std::queue<uint32_t> queue; | ||
std::vector<uint32_t> inDegree(nodeCount, 0); | ||
std::vector<std::vector<uint32_t>> children(nodeCount); | ||
|
||
// Don't need to iterate through InputEdges because those inputs don't represent any node | ||
// and the purpose of this topological sort is to come up with a order to correctly iterate | ||
// through nodes . | ||
for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphDesc.IntermediateEdges) | ||
{ | ||
inDegree[intermediateEdge.ToNodeIndex]++; | ||
children[intermediateEdge.FromNodeIndex].push_back(intermediateEdge.ToNodeIndex); | ||
} | ||
|
||
for (uint32_t nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) | ||
{ | ||
if (inDegree[nodeIndex] == 0) | ||
{ | ||
queue.push(nodeIndex); | ||
} | ||
} | ||
|
||
uint32_t nodeIndex = 0; | ||
while (!queue.empty()) | ||
{ | ||
if (nodeIndex >= nodeCount) | ||
{ | ||
throw std::invalid_argument("Given graph is not acyclic."); | ||
} | ||
|
||
uint32_t currNodeIndex = queue.front(); | ||
queue.pop(); | ||
nodesInTopologicalOrder[nodeIndex++] = currNodeIndex; | ||
|
||
for (uint32_t child : children[currNodeIndex]) | ||
{ | ||
if (--inDegree[child] == 0) | ||
{ | ||
queue.push(child); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters