Skip to content

Commit

Permalink
[DML EP] DML Graph Serialization Bug (#19748)
Browse files Browse the repository at this point in the history
### Description
This pull request addresses several issues:

- The DML Graph's nodes were not sorted in a topologically ordered
sequence, leading to crashes during deserialization when a child node
preceded its parent node. This PR resolves this issue by implementing a
topological sorting algorithm before serialization.

- During the `RemoveUnconnectedNodes` process:
- we update `intermediateEdge.FromNodeIndex`. Additionally, we must
update `intermediateEdge.Name` when it includes
`intermediateEdge.FromNodeIndex`, as serialization/deserialization
heavily relies on edge names.

- we also eliminate unused edges. Consequently, we must erase the
now-unused inputs from the corresponding maps
`serializedGraphInputIndexToSubgraphInputIndex` and
`serializedGraphLargeConstantNameToSubgraphInputIndex`.


### Motivation and Context
Why is this change required? What problem does it solve?
A few public ONNX Zoo models were crashing during
deserialization.
<!-- - - If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Jeff Bloomfield <[email protected]>
  • Loading branch information
sumitsays and jeffbloo authored Mar 31, 2024
1 parent a0ebd5f commit e1e292f
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ template <typename EdgeType> void PopulateEdges(
if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end())
{
throw std::range_error("Neither there is any graph input with name " + edgeName->str() +
"nor there is any node which has " + edgeName->str() + " as one of the output.");
" nor there is any node which has " + edgeName->str() + " as one of the output.");
}
auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()];
DmlIntermediateSerializedGraphEdge intermediateEdge = {};
Expand Down Expand Up @@ -475,14 +475,15 @@ DmlSerializedGraphDesc DeserializeDmlGraph(
inputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);

PopulateEdges<DmlOutputSerializedGraphEdge>(
nodeIndex,
flatbufferNode->outputNames(),
graphOutputEdgeToIndexMap,
outputEdges,
intermediateEdges,
edgeToOutgoingNodeIndexMap);

DmlSerializedGraphNode node = {};
if (flatbufferNode->name()->size() == 0)
{
Expand All @@ -503,7 +504,7 @@ DmlSerializedGraphDesc DeserializeDmlGraph(

ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()};
node.Desc = constantNode;
// output of this node will part of constantInputs list
// Output of this node will be part of constantInputs list.
for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++)
{
constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,37 +596,37 @@ namespace DmlGraphFusionHelper
const std::unordered_map<uint32_t, uint32_t>* serializedGraphInputIndexToSubgraphInputIndex,
const std::unordered_map<std::string_view, uint32_t>* serializedGraphLargeConstantNameToSubgraphInputIndex)
{
if (graphSerializationEnabled)
{

const std::wstring modelName = GetModelName(graph.ModelPath());
auto buffer = SerializeDmlGraph(graphDesc);

const std::wstring partitionName =
L"Partition_" +
std::to_wstring(partitionIndex) +
L".bin";
WriteToFile(modelName, partitionName, buffer.data(), buffer.size());

std::vector<std::unique_ptr<std::byte[]>> rawData;
DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData);
GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {};
deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount;
deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges);
deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges);
deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes);
deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount;
deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges);
deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList;
deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes;

compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
deserializedDmlGraphDesc,
indexedSubGraph,
providerImpl,
serializedGraphInputIndexToSubgraphInputIndex,
serializedGraphLargeConstantNameToSubgraphInputIndex);
}
if (graphSerializationEnabled)
{
const std::wstring modelName = GetModelName(graph.ModelPath());
auto buffer = SerializeDmlGraph(graphDesc);
const std::wstring partitionName =
L"Partition_" +
std::to_wstring(partitionIndex) +
L".bin";
WriteToFile(modelName, partitionName, buffer.data(), buffer.size());
std::vector<std::unique_ptr<std::byte[]>> rawData;
DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData);
GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {};
deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount;
deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges);
deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges);
deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes);
deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount;
deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges);
deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList;
deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes;
compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator(
deserializedDmlGraphDesc,
indexedSubGraph,
providerImpl,
serializedGraphInputIndexToSubgraphInputIndex,
serializedGraphLargeConstantNameToSubgraphInputIndex);
}

auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name);
fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& grap
return builder.Release();
}

std::vector<uint32_t> nodesInTopologicalOrder(graphDesc.Nodes.size());
PerformTopologicalSortAndCheckIsAcyclic(graphDesc, nodesInTopologicalOrder);

// create input/output edge index to name map
std::unordered_map<uint32_t, flatbuffers::Offset<flatbuffers::String>> graphInputIndexToNameMap =
ConvertToEdgeIndexToNameMap<DmlInputSerializedGraphEdge>(graphDesc.InputEdges, builder);
Expand Down Expand Up @@ -548,14 +551,14 @@ flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& grap

// Create flatbuffer node objects
std::vector<flatbuffers::Offset<dml::ir::DmlGraphNode>> nodes(graphDesc.Nodes.size());
for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(graphDesc.Nodes.size()); nodeIndex++)
for (uint32_t nodeIndex = 0; nodeIndex < static_cast<uint32_t>(nodesInTopologicalOrder.size()); nodeIndex++)
{
nodes[nodeIndex] = SerializeNode(
builder,
nodeIndex,
graphDesc.Nodes[nodeIndex],
nodeToInputNames[nodeIndex],
nodeToOutputNames[nodeIndex]);
nodesInTopologicalOrder[nodeIndex],
graphDesc.Nodes[nodesInTopologicalOrder[nodeIndex]],
nodeToInputNames[nodesInTopologicalOrder[nodeIndex]],
nodeToOutputNames[nodesInTopologicalOrder[nodeIndex]]);
}

// Convert to std::vector to create the <dml::ir::DmlGraphDesc> object.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

#pragma once
#include <queue>

// Computes a topological ordering of the nodes of a serialized DML graph using
// Kahn's algorithm (repeatedly emitting nodes whose remaining in-degree is zero).
//
// graphDesc:
//     Graph to sort. Only Nodes.size() and IntermediateEdges are consulted.
//     NOTE(review): edge FromNodeIndex/ToNodeIndex are assumed to be valid indices
//     into Nodes; they are not range-checked here.
// nodesInTopologicalOrder:
//     Output. Resized to the node count; on success it contains every node index
//     exactly once, with the source node of each intermediate edge placed before
//     the edge's destination node.
//
// Throws std::invalid_argument when the graph contains a cycle, i.e. when not every
// node could be emitted.
inline void PerformTopologicalSortAndCheckIsAcyclic(
    const DmlSerializedGraphDesc& graphDesc,
    std::vector<uint32_t>& nodesInTopologicalOrder)
{
    const uint32_t nodeCount = static_cast<uint32_t>(graphDesc.Nodes.size());
    std::queue<uint32_t> queue;
    std::vector<uint32_t> inDegree(nodeCount, 0);
    std::vector<std::vector<uint32_t>> children(nodeCount);

    // Guarantee that the indexed writes below stay in bounds regardless of how the
    // caller sized the output vector.
    nodesInTopologicalOrder.resize(nodeCount);

    // Don't need to iterate through InputEdges because those inputs don't represent any node,
    // and the purpose of this topological sort is to come up with an order in which to
    // correctly iterate through the nodes.
    for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphDesc.IntermediateEdges)
    {
        inDegree[intermediateEdge.ToNodeIndex]++;
        children[intermediateEdge.FromNodeIndex].push_back(intermediateEdge.ToNodeIndex);
    }

    // Seed the queue with every node that has no incoming intermediate edge.
    for (uint32_t nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++)
    {
        if (inDegree[nodeIndex] == 0)
        {
            queue.push(nodeIndex);
        }
    }

    // Each node is enqueued at most once (when its in-degree reaches zero), so
    // sortedNodeCount can never exceed nodeCount inside this loop.
    uint32_t sortedNodeCount = 0;
    while (!queue.empty())
    {
        const uint32_t currNodeIndex = queue.front();
        queue.pop();
        nodesInTopologicalOrder[sortedNodeCount++] = currNodeIndex;

        for (const uint32_t child : children[currNodeIndex])
        {
            if (--inDegree[child] == 0)
            {
                queue.push(child);
            }
        }
    }

    // Nodes that lie on a cycle never reach in-degree zero and therefore are never
    // enqueued. A cycle shows up as the queue draining before all nodes were emitted,
    // so the check has to happen after the loop rather than inside it.
    if (sortedNodeCount != nodeCount)
    {
        throw std::invalid_argument("Given graph is not acyclic.");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ namespace Dml::GraphDescBuilder
std::vector<DmlSerializedGraphNode>& graphNodes,
std::vector<DmlInputSerializedGraphEdge>& graphInputEdges,
std::vector<DmlIntermediateSerializedGraphEdge>& graphIntermediateEdges,
std::vector<DmlOutputSerializedGraphEdge>& graphOutputEdges)
std::vector<DmlOutputSerializedGraphEdge>& graphOutputEdges,
std::unordered_map<uint32_t, uint32_t>& serializedGraphInputIndexToSubgraphInputIndex,
std::unordered_map<std::string_view, uint32_t>& serializedGraphLargeConstantNameToSubgraphInputIndex)
{
enum class NodeState
{
Expand Down Expand Up @@ -124,8 +126,10 @@ namespace Dml::GraphDescBuilder
graphNodes.resize(graphNodes.size() - shift);

// Adjust the node indices in the input edges
std::unordered_set<uint32_t> usedInputEdgeIndex;
for (auto& inputEdge : graphInputEdges)
{
usedInputEdgeIndex.insert(inputEdge.GraphInputIndex);
inputEdge.ToNodeIndex = shiftedIndicesMapping[inputEdge.ToNodeIndex];
}

Expand All @@ -136,10 +140,54 @@ namespace Dml::GraphDescBuilder
}

// Adjust the node indices in the intermediate edges
std::unordered_set<std::string> usedLargeConstantNames;
for (auto& intermediateEdge : graphIntermediateEdges)
{
intermediateEdge.FromNodeIndex = shiftedIndicesMapping[intermediateEdge.FromNodeIndex];
intermediateEdge.ToNodeIndex = shiftedIndicesMapping[intermediateEdge.ToNodeIndex];
// We need to update the edge name only when the name contains the intermediateEdge.FromNodeIndex
size_t pos = intermediateEdge.Name.find("nodeIdx:");
if (pos != std::string::npos)
{
if (pos != 0)
{
std::string constantNamePartComingFromModel = intermediateEdge.Name.substr(0, pos - 1);
usedLargeConstantNames.insert(constantNamePartComingFromModel); // need part of name which is coming from the model.
intermediateEdge.Name = constantNamePartComingFromModel;
intermediateEdge.Name += "-nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex);
}
else
{
intermediateEdge.Name = "nodeIdx:" + std::to_string(intermediateEdge.FromNodeIndex) + "-outputIdx:" + std::to_string(intermediateEdge.FromNodeOutputIndex);
}
}
}


// Erase the mapping if the input Edge is not used by any node
for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end();)
{
if (!usedInputEdgeIndex.count(it->first))
{
it = serializedGraphInputIndexToSubgraphInputIndex.erase(it);
}
else
{
it++;
}
}

// Erase the mapping if the input Edge is not used by any node
for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end();)
{
if (!usedLargeConstantNames.count(std::string(it->first)))
{
it = serializedGraphLargeConstantNameToSubgraphInputIndex.erase(it);
}
else
{
it++;
}
}
}

Expand Down Expand Up @@ -516,7 +564,12 @@ namespace Dml::GraphDescBuilder
graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex);
}

RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges);
RemoveUnconnectedNodes(dmlGraphNodes,
dmlGraphInputEdges,
dmlGraphIntermediateEdges,
dmlGraphOutputEdges,
serializedGraphInputIndexToSubgraphInputIndex,
serializedGraphLargeConstantNameToSubgraphInputIndex);

GraphDesc graphDesc{};
graphDesc.InputCount = static_cast<uint32_t>(dmlGraphInputEdges.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include "External/DirectMLHelpers/DmlSerializedGraphDesc.h"
#include "External/DirectMLHelpers/DmlGraphSerialization.h"
#include "External/DirectMLHelpers/DmlGraphDeserialization.h"
#include "External/DirectMLHelpers/DmlGraphHelper.h"

using Microsoft::WRL::ComPtr;

Expand Down

0 comments on commit e1e292f

Please sign in to comment.