[TensorRT EP] Load precompiled TRT engine file directly #18217

Merged Jan 12, 2024: 53 commits merged into main from chi/trt_engine_wrapper.

Commits (all by chilo-ms):
f8099b1  update (Oct 31, 2023)
ee81de5  fix bug (Nov 1, 2023)
452a629  update (Nov 1, 2023)
6f18d8d  remove redundant check (Nov 1, 2023)
9430768  remove unused variable (Nov 1, 2023)
e9507e5  add script to generate epcontext node (Nov 1, 2023)
35a4d33  fix bug (Nov 1, 2023)
44a7cc5  update (Nov 1, 2023)
b2fdb06  update (Nov 2, 2023)
eeb6552  refactor (Nov 2, 2023)
1b5117d  update (Nov 2, 2023)
df7ef46  update (Nov 3, 2023)
993b2ad  change function name (Nov 3, 2023)
34a86d7  Merge branch 'main' into chi/trt_engine_wrapper (Nov 3, 2023)
9631f73  update (Nov 4, 2023)
93f9fbb  update (Nov 5, 2023)
d5974fc  refactor (Nov 6, 2023)
7202b73  check compute capability (Nov 6, 2023)
60f6e7e  Merge branch 'main' into chi/trt_engine_wrapper (Nov 6, 2023)
5838143  update for reading engine byte data (Nov 6, 2023)
aea26b1  add script to generate engine wrapper onnx model (Nov 7, 2023)
f4b38f7  refactor (Nov 7, 2023)
de9f510  fix bug (Nov 7, 2023)
65331ed  fix format (Nov 7, 2023)
befba02  refactor (Nov 7, 2023)
5103dda  add unit test (Nov 7, 2023)
f933379  refactor script (Nov 7, 2023)
11fd212  fix format (Nov 7, 2023)
561b059  update (Nov 9, 2023)
789efe8  fix format (Nov 9, 2023)
a5843c2  fix format (Nov 9, 2023)
11ce3fc  fix gen_trt_engine_wrapper_onnx_model.py (Nov 10, 2023)
f9206f3  refactor unit test (Nov 10, 2023)
0bd5b8c  update (Nov 10, 2023)
e2c3f16  Merge branch 'main' into chi/trt_engine_wrapper (Nov 19, 2023)
6605fd4  update script (Nov 19, 2023)
699d538  fix format (Nov 19, 2023)
1cbe3e9  fix bug for conflict resolve (Nov 20, 2023)
f8e775d  refactor script (Nov 20, 2023)
8f7c7ac  generate ep context node model from TRT EP (Nov 23, 2023)
77a62f2  add trt_dump_ep_context_model, trt_ep_context_embed_mode, trt_ep_cont… (Nov 23, 2023)
04baad7  fix format (Nov 24, 2023)
c3a028b  swap the position of CreateNodeComputeFromGraph and CreateNodeCompute… (Jan 8, 2024)
db46b64  merge PR 18879 and PR 18834 (Jan 9, 2024)
842cdf0  merge PR 18879 and PR 18834 (continue) (Jan 9, 2024)
ca8d49f  Merge branch 'main' into chi/trt_engine_wrapper (Jan 9, 2024)
e0d3346  remove kernelcontext_setoutput (Jan 9, 2024)
f9231a5  fix bugs after merge main (Jan 9, 2024)
cebfcd8  apply lintrunner -a (Jan 10, 2024)
b1c4305  Use 'hardware_architecture' (Jan 11, 2024)
0435971  merge main and also enforce get compute capability once inside ep con… (Jan 11, 2024)
77bf077  Merge branch 'main' into chi/trt_engine_wrapper (Jan 11, 2024)
28bdd0a  remove unnecessary code (Jan 11, 2024)
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -1588,6 +1588,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.</dd>
<dt><tt>ep_sdk_version</tt> : string</dt>
<dd>(Optional) SDK version used to convert the model.</dd>
<dt><tt>hardware_architecture</tt> : string</dt>
<dd>(Optional) Hardware architecture.</dd>
<dt><tt>main_context</tt> : int</dt>
<dd>Usually each EPContext node is associated with one graph partition. But in some cases, such as QNN, a single EPContext node contains all partitions. In that case, the node holding ep_cache_context should set main_context=1, while the other nodes set main_context=0 and skip ep_cache_context. The path is relative to this Onnx file. Default is 1.</dd>
<dt><tt>notes</tt> : string</dt>
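To make the schema above concrete, here is a minimal, hypothetical Python sketch that builds a one-node EPContext wrapper model with embed_mode=0, i.e. ep_cache_context carries the engine cache path. The I/O names, shapes, file names, and the hardware_architecture value format are placeholders; the PR's gen_trt_engine_wrapper_onnx_model.py script is the authoritative generator.

```python
# Hypothetical sketch: wrap a prebuilt TensorRT engine path in an EPContext node.
# I/O names, shapes, and file names are placeholders, not from this PR.
import onnx
from onnx import TensorProto, helper

node = helper.make_node(
    "EPContext",
    inputs=["input"],
    outputs=["output"],
    name="EPContext_0",
    domain="com.microsoft",
    embed_mode=0,                      # 0 = ep_cache_context holds a path
    ep_cache_context="model.engine",   # engine cache path relative to this .onnx file
    hardware_architecture="86",        # e.g. compute capability 8.6 (assumed format)
)

graph = helper.make_graph(
    [node],
    "trt_engine_wrapper",
    [helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 224, 224])],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1000])],
)

model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "model.engine_wrapper.onnx")
```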
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -46,4 +46,7 @@
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
int trt_dump_ep_context_model{0}; // Dump EP context node model
int trt_ep_context_embed_mode{0};  // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
};
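For context, here is a hypothetical Python usage sketch for the three new options added above. The model path and option values are placeholders; trt_engine_cache_enable is a pre-existing option shown for completeness.

```python
# Hypothetical usage sketch: build the engine once and dump an EPContext
# ("engine wrapper") model alongside the engine cache. Paths are placeholders.
import onnxruntime as ort

providers = [
    ("TensorrtExecutionProvider", {
        "trt_engine_cache_enable": True,                    # pre-existing option
        "trt_dump_ep_context_model": True,                  # new: dump EP context node model
        "trt_ep_context_embed_mode": 0,                     # new: 0 = path, 1 = engine binary
        "trt_ep_context_compute_capability_enable": True,   # new: record hardware_architecture
    }),
    "CUDAExecutionProvider",
]

session = ort.InferenceSession("model.onnx", providers=providers)
```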
5 changes: 5 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3230,6 +3230,11 @@ void RegisterContribSchemas() {
"(Optional) SDK version used to convert the model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"hardware_architecture",
"(Optional) Hardware architecture.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"partition_name",
"(Optional) partitioned graph name.",
@@ -330,6 +330,7 @@ struct ProviderHost {
virtual int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0;
virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0;
virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0;
virtual void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) = 0;
@@ -351,6 +352,7 @@ struct ProviderHost {
virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0;

// ModelProto
virtual std::unique_ptr<ONNX_NAMESPACE::ModelProto> ModelProto__construct() = 0;
@@ -372,6 +374,7 @@ struct ProviderHost {
virtual void NodeProto__operator_assign(ONNX_NAMESPACE::NodeProto* p, const ONNX_NAMESPACE::NodeProto& v) = 0;
virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0;
virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0;
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0;

// TensorProto
virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
@@ -74,6 +74,7 @@ struct AttributeProto final {
int64_t i() const { return g_host->AttributeProto__i(this); }
float f() const { return g_host->AttributeProto__f(this); }
void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); }
void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); }
const ::std::string& s() const { return g_host->AttributeProto__s(this); }
void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); }
void set_type(AttributeProto_AttributeType value) { return g_host->AttributeProto__set_type(this, value); }
@@ -118,6 +119,7 @@ struct GraphProto final {
ValueInfoProtos* mutable_value_info() { return g_host->GraphProto__mutable_value_info(this); }
TensorProtos* mutable_initializer() { return g_host->GraphProto__mutable_initializer(this); }
NodeProto* add_node() { return g_host->GraphProto__add_node(this); }
NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); }

GraphProto() = delete;
GraphProto(const GraphProto&) = delete;
@@ -148,6 +150,7 @@ struct NodeProto final {
void operator=(const NodeProto& v) { g_host->NodeProto__operator_assign(this, v); }
int attribute_size() { return g_host->NodeProto__attribute_size(this); }
const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); }
AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); }

NodeProto() = delete;
NodeProto(const NodeProto&) = delete;
229 changes: 229 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -0,0 +1,229 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>
#include <fstream>
#include <filesystem>

#include "onnx_ctx_model_helper.h"

Check warning on line 8 in onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc#L8

Include the directory when naming header files [build/include_subdir] [4]
Raw output
onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc:8:  Include the directory when naming header files  [build/include_subdir] [4]
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/execution_provider.h"

namespace onnxruntime {

/*
* Check whether the graph has the EP context contrib op.
* The op can contain the precompiled engine info for TRT EP to directly load the engine.
*
* Note: see contrib_defs.cc for more details about the "EPContext" contrib op.
*/
bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
return true;
}
}
return false;
}

const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
while (cur_graph->IsSubgraph()) {
cur_graph = cur_graph->ParentGraph();
}

const Graph& main_graph = *cur_graph;
return main_graph.ModelPath();
}

std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path) {
std::filesystem::path base_path(path.ToPathString());
std::filesystem::path parent_path = base_path.parent_path();
std::filesystem::path engine_path = parent_path.append(engine_cache_path);
return engine_path;
}

/*
* Update ep_cache_context attribute of the EP context node with the given engine binary data
*/
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size) {
ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);

for (int i = 0; i < node_proto->attribute_size(); ++i) {
ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
if (attribute_proto->name() == EP_CACHE_CONTEXT) {
std::string engine_data_str = "";
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attribute_proto->set_s(engine_data_str);
}
}
}

/*
* Create "EP context node" model where engine information is embedded
*/
ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();

// Get graph inputs and outputs
std::vector<onnxruntime::NodeArg*> inputs, outputs;

for (auto input : graph_viewer.GetInputs()) {
auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
inputs.push_back(&n_input);
}

for (auto output : graph_viewer.GetOutputs()) {
auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
outputs.push_back(&n_output);
}

// Create EP context node attributes
auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode
auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context
auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture
std::string engine_data_str = "";
attr_0->set_name(EMBED_MODE);
attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
attr_0->set_i(embed_mode);
attr_1->set_name(EP_CACHE_CONTEXT);
attr_1->set_type(onnx::AttributeProto_AttributeType_STRING);
if (embed_mode) {
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attr_1->set_s(engine_data_str);
} else {
attr_1->set_s(engine_cache_path);
}
auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
int num_attributes = compute_capability_enable ? 3 : 2;
node_attributes->reserve(num_attributes);
node_attributes->emplace(EMBED_MODE, *attr_0);
node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);

if (compute_capability_enable) {
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(compute_capability);
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
}

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
ORT_ENFORCE(graph_build.Resolve().IsOK());

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto model = new_graph_viewer->CreateModel(*logger);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

return model_proto.release();
}

/*
* Dump "EP context node" model
*
*/
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path) {
std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
if (embed_mode) {
// Get engine from byte stream
const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
                                                                                             static_cast<size_t>(context_binary.length())));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
if (!(*trt_engine_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from binary data");
}
} else {
// Get engine from cache file
std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in);
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read(engine_buf.get(), engine_size);
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path_.string();
if (!(*trt_engine_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string());
}
}
return Status::OK();
}

/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

// Check hardware_architecture (compute capability) if it's present as an attribute
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
if (model_compute_capability != compute_capability_) {
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability";

Check warning on line 200 in onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc#L200

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc:200:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability;
LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_;
return false;
}
}

// "embed_mode" attr and "ep_cache_context" attr should be present
if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) {
// ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0
const int64_t embed_mode = attrs.at(EMBED_MODE).i();

// engine cache path
if (embed_mode == 0) {
// First assume the engine cache path is relative to the model path;
// if not, assume it is an absolute path.
engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer));
auto default_engine_cache_path_ = engine_cache_path_;
if (!std::filesystem::exists(engine_cache_path_)) {
engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s());
if (!std::filesystem::exists(engine_cache_path_)) {
LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine";

Check warning on line 221 in onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc#L221

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc:221:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
return false;
}
}
}
}
return true;
}
} // namespace onnxruntime
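To make the embed_mode=1 path concrete, here is a hypothetical Python sketch that performs the same attribute update as UpdateCtxNodeModelEngineContext above, but outside of ORT. File names are placeholders.

```python
# Hypothetical sketch mirroring UpdateCtxNodeModelEngineContext: store serialized
# engine bytes in the ep_cache_context attribute and flip embed_mode to 1.
import onnx

model = onnx.load("model.engine_wrapper.onnx")   # placeholder wrapper model
node = model.graph.node[0]                       # the single EPContext node

with open("model.engine", "rb") as f:            # placeholder engine cache file
    engine_bytes = f.read()

for attr in node.attribute:
    if attr.name == "embed_mode":
        attr.i = 1                               # 1 = context is engine binary data
    elif attr.name == "ep_cache_context":
        attr.s = engine_bytes                    # STRING attribute holds raw bytes

onnx.save(model, "model.engine_wrapper_embedded.onnx")
```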
55 changes: 55 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <filesystem>

#include "NvInfer.h"
#include "core/providers/shared_library/provider_api.h"

namespace onnxruntime {

static const std::string EPCONTEXT_OP = "EPContext";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path);
ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger);
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);

class TensorRTCacheModelHandler {
public:
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
nvinfer1::IRuntime* trt_runtime,
std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
nvinfer1::IRuntime* trt_runtime_;
std::filesystem::path engine_cache_path_;
std::string compute_capability_;
};  // TensorRTCacheModelHandler
} // namespace onnxruntime
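End to end, the wrapper model can then be loaded like any other ONNX model; with the TensorRT EP enabled, TensorRTCacheModelHandler deserializes the referenced (or embedded) engine instead of rebuilding it. A hypothetical sketch, with the input name and shape as placeholders:

```python
# Hypothetical end-to-end sketch: loading the EPContext wrapper model lets the
# TensorRT EP deserialize the precompiled engine directly.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.engine_wrapper.onnx",
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
)
x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # placeholder input
outputs = session.run(None, {"input": x})
```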