Skip to content

Commit

Permalink
Bug fix for nested control flow ops for TRT EP (#16343)
Browse files Browse the repository at this point in the history
The current TRT EP can support models which have nested control flow ops
(multiple levels of subgraphs). But it fails in a case where the subgraph
has an outer scope value that is defined several levels up in the top-level
graph; in this case, the outer scope value is an input of the top-level
graph. The outer scope values are not properly handled during TRT EP's
subgraph reconstruction stage, and resolution fails at `graph.resolve()`.

The way ORT gets capability from EPs is a bottom-up approach meaning
inner most subgraph gets handled first. TRT EP reconstructs each
subgraph level by level and following modifications are made to fix the
outer scope values issue:

- `SetGraphOuterScopeValuesAndInputs()` and `SetAllGraphInputs()` are
added to handle outer scope values and add those values as graph inputs
if needed in order to make `graph.resolve()` happy.
- Change to use `GetNodeArgIncludingParentGraphs` so that when creating
the fused TRT node for some subgraphs in
`Graph::CreateFusedSubGraphNode()`, it can get the NodeArgs for outer
scope values from the top-level graph.


This PR fixes #16217
  • Loading branch information
chilo-ms authored and jchen351 committed Aug 12, 2023
1 parent 7f1de89 commit f768269
Show file tree
Hide file tree
Showing 9 changed files with 604 additions and 2 deletions.
8 changes: 6 additions & 2 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3878,13 +3878,17 @@ Node& Graph::CreateFusedSubGraphNode(const IndexedSubGraph& sub_graph, const std

int cur_idx = 0;
for (const auto& arg_name : func_meta_def->inputs) {
input_args.push_back(GetNodeArg(arg_name));
// In some cases, it needs to get the NodeArgs from ancestors.
// For example, if the subgraph we are going to build is the subgraph of the original graph
// and the NodeArgs of the outer scope values are defined in the top-level original graph.
input_args.push_back(GetNodeArgIncludingParentGraphs(arg_name));
input_indexes[arg_name] = cur_idx++;
}

cur_idx = 0;
for (const auto& arg_name : func_meta_def->outputs) {
output_args.push_back(GetNodeArg(arg_name));
// In some cases, it needs to get the NodeArgs from ancestors.
output_args.push_back(GetNodeArgIncludingParentGraphs(arg_name));
output_indexes[arg_name] = cur_idx++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,8 @@ struct ProviderHost {
virtual std::unique_ptr<Node__EdgeIterator> Node__OutputEdgesEnd(const Node* p) noexcept = 0;

virtual void Node__ForEachDef(const Node* p, std::function<void(const NodeArg&, bool is_input)> func, bool include_missing_optional_defs) = 0;
virtual const std::unordered_map<std::string, gsl::not_null<Graph*>>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) = 0;
virtual std::unordered_map<std::string, gsl::not_null<const Graph*>> Node__GetAttributeNameToSubgraphMap(const Node* p) const = 0;

// NodeArg
virtual const std::string& NodeArg__Name(const NodeArg* p) noexcept = 0;
Expand Down Expand Up @@ -695,6 +697,8 @@ struct ProviderHost {
virtual std::unique_ptr<ONNX_NAMESPACE::GraphProto> Graph__ToGraphProto(const Graph* p) = 0;

virtual NodeArg& Graph__GetOrCreateNodeArg(Graph* p, const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) = 0;
virtual void Graph__AddOuterScopeNodeArg(Graph* p, const std::string& name) = 0;
virtual void Graph__SetInputs(Graph* p, gsl::span<const NodeArg* const> inputs) = 0;

virtual Status Graph__Resolve(Graph* p) = 0;
virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0;
Expand All @@ -708,10 +712,15 @@ struct ProviderHost {

virtual const Node* Graph__ParentNode(const Graph* p) const = 0;
virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0;
virtual Graph* Graph__MutableParentGraph(Graph* p) = 0;
virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0;
virtual const Path& Graph__ModelPath(const Graph* p) const = 0;
virtual const std::vector<const NodeArg*>& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0;
virtual bool Graph__IsSubgraph(const Graph* p) = 0;
virtual int Graph__MaxNodeIndex(const Graph* p) const noexcept = 0;
virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0;
virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0;
virtual const NodeArg* Graph__GetNodeArg(const Graph* p, const std::string& name) const = 0;

// GraphViewer
virtual void GraphViewer__operator_delete(GraphViewer* p) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,8 @@ struct Node final {
EdgeConstIterator OutputEdgesEnd() const noexcept { return g_host->Node__OutputEdgesEnd(this); }

void ForEachDef(std::function<void(const NodeArg&, bool is_input)> func, bool include_missing_optional_defs = false) const { g_host->Node__ForEachDef(this, func, std::move(include_missing_optional_defs)); }
const std::unordered_map<std::string, gsl::not_null<Graph*>>& GetAttributeNameToMutableSubgraphMap() { return g_host->Node__GetAttributeNameToMutableSubgraphMap(this); }
std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const { return g_host->Node__GetAttributeNameToSubgraphMap(this); }

PROVIDER_DISALLOW_ALL(Node)
};
Expand Down Expand Up @@ -707,6 +709,8 @@ struct Graph final {
std::unique_ptr<ONNX_NAMESPACE::GraphProto> ToGraphProto() const { return g_host->Graph__ToGraphProto(this); }

NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) { return g_host->Graph__GetOrCreateNodeArg(this, name, p_arg_type); }
void AddOuterScopeNodeArg(const std::string& name) { g_host->Graph__AddOuterScopeNodeArg(this, name); }
void SetInputs(gsl::span<const NodeArg* const> inputs) { g_host->Graph__SetInputs(this, inputs); }

Status Resolve() { return g_host->Graph__Resolve(this); }
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor) { return g_host->Graph__AddInitializedTensor(this, tensor); }
Expand All @@ -721,10 +725,15 @@ struct Graph final {

const Node* ParentNode() const { return g_host->Graph__ParentNode(this); }
const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); }
Graph* MutableParentGraph() { return g_host->Graph__MutableParentGraph(this); }
const std::string& Name() const noexcept { return g_host->Graph__Name(this); }
const Path& ModelPath() const { return g_host->Graph__ModelPath(this); }
const std::vector<const NodeArg*>& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); }
bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); }
int MaxNodeIndex() const noexcept { return g_host->Graph__MaxNodeIndex(this); }
const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); }
Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); }
const NodeArg* GetNodeArg(const std::string& name) const { return g_host->Graph__GetNodeArg(this, name); }

PROVIDER_DISALLOW_ALL(Graph)
};
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
} else {
auto model_build = graph.CreateModel(*GetLogger());
auto& graph_build = model_build->MainGraph();
bool has_control_flow_op = false;

// Add node and node args
// If node output is also parent graph output, the output will be added to the
Expand Down Expand Up @@ -1321,6 +1322,10 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
}
}

if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
has_control_flow_op = true;
}

// If the node has subgraph, it's possible that the ORT graph of that subgraph and the GraphProto in the node attributes are not in sync because of graph optimization.
// Therefore, we need to force GraphProto attributes to be updated in order to get the valid GraphProto.
if (node->GetAttributes().size() > 0) {
Expand All @@ -1345,6 +1350,13 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
}
}

if (has_control_flow_op) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name();
BuildSubGraphContext(graph_build);
SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph());
SetAllGraphInputs(graph_build);
}

ORT_ENFORCE(graph_build.Resolve().IsOK());

// Add parent graph output to the subgraph
Expand Down Expand Up @@ -1657,6 +1669,20 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0);
SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}};
bool subgraph_early_termination = false;

// Another subgraph of "If" control flow has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP.
if (AllNodesAssignedToSpecificEP(*sub_graph_veiwer, kTensorrtExecutionProvider)) {
all_subgraphs_are_supported = true;
break;
}
// Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP.
// (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs)
else if (!AllNodesAssignedToSpecificEP(*sub_graph_veiwer, "")) {
all_subgraphs_are_supported = false;
break;
}

// Another subgraph of "If" control flow has not yet been parsed by GetCapability.
subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_veiwer, &subgraph_early_termination);
all_subgraphs_are_supported = IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes);
break;
Expand All @@ -1677,6 +1703,9 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
}
}
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";

// The context map is only used during EP compile time, release it to save memory space.
subgraph_context_map_.clear();
return result;
}
}
Expand All @@ -1700,6 +1729,8 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
}

// The context map is only used during EP compile time, release it to save memory space.
subgraph_context_map_.clear();
return result;
}

Expand Down
46 changes: 46 additions & 0 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ struct TensorrtFuncState {
bool cuda_graph_enable = 0;
};

// Holds important information for building valid ORT graph.
// One instance is built per reconstructed subgraph by BuildSubGraphContext() and
// stored in subgraph_context_map_; SetGraphOuterScopeValuesAndInputs() and
// SetAllGraphInputs() consume it so that Graph::Resolve() succeeds for graphs
// containing nested control-flow ops.
struct SubGraphContext {
// Names of values produced inside the subgraph — presumably node output args;
// NOTE(review): confirm against BuildSubGraphContext().
std::unordered_set<std::string> output_args;
// Name -> NodeArg for values already available locally (graph inputs and
// initializers), i.e. values that do NOT need outer-scope handling.
std::unordered_map<std::string, const NodeArg*> inputs_and_initializers;
// Name -> NodeArg for outer-scope values that the EP had to add as explicit
// graph inputs (see SetGraphOuterScopeValuesAndInputs / SetAllGraphInputs).
std::unordered_map<std::string, const NodeArg*> manually_added_graph_inputs;
};

// Maps a graph (by string key, presumably the graph name) to its SubGraphContext.
using SubGraphContextMap = std::unordered_map<std::string, std::unique_ptr<SubGraphContext>>;

// Logical device representation.
class TensorrtExecutionProvider : public IExecutionProvider {
public:
Expand Down Expand Up @@ -224,6 +233,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
std::unordered_map<std::string, std::unique_ptr<nvinfer1::ICudaEngine>> engines_;
std::unordered_map<std::string, std::unique_ptr<nvinfer1::IExecutionContext>> contexts_;
Expand Down Expand Up @@ -273,6 +283,42 @@ class TensorrtExecutionProvider : public IExecutionProvider {
/**Check whether all the nodes of subgraph are supported*/
bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const;

/**
* Set inputs, initializers and outputs for all subgraphs during TensorrtExecutionProvider::GetSupportedList()
* and save that information in the subgraph context data structure. It's useful for building a valid graph and
* making Graph::Resolve() happy, especially when dealing with nested control-flow op graphs.
*/
void BuildSubGraphContext(const Graph& build_graph) const;

/**
* Set outer scope values for subgraphs and add those values as top-level graph's inputs if needed.
*/
void SetGraphOuterScopeValuesAndInputs(Graph& build_graph, const Graph& graph) const;

/**
* If ORT TRT manually sets graph input in TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(),
* we have to manually set all the graph inputs in order to pass Graph::Resolve().
*/
void SetAllGraphInputs(Graph& graph) const;

/**
* The newly-built graph has not yet been resolved by Graph::Resolve(), so we can't leverage
* Graph::ResolveContext::IsInputInitializerOrOutput(). We have to implement this function again.
*/
bool IsInputInitializerOrOutput(const Graph& graph, const std::string& name, bool check_ancestors) const;

/**
* The newly-built graph has not yet been resolved by Graph::Resolve(), so we can't leverage
* Graph::ResolveContext::IsOuterScopeValue(). We have to implement this function again.
*/
bool IsOuterScopeValue(const Graph& graph, const std::string& name) const;

/**
* The newly-built graph has not yet been resolved by Graph::Resolve(), so we can't leverage
* Graph::ResolveContext::IsLocalValue(). We have to implement this function again.
*/
bool IsLocalValue(const Graph& graph, const std::string& name) const;

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
Expand Down
Loading

0 comments on commit f768269

Please sign in to comment.