
[TensorRT EP] Load precompiled TRT engine file directly #18217

Merged · 53 commits · Jan 12, 2024

Changes from 17 commits

Commits
f8099b1
update
chilo-ms Oct 31, 2023
ee81de5
fix bug
chilo-ms Nov 1, 2023
452a629
update
chilo-ms Nov 1, 2023
6f18d8d
remove redundant check
chilo-ms Nov 1, 2023
9430768
remove unused variable
chilo-ms Nov 1, 2023
e9507e5
add script to generate epcontext node
chilo-ms Nov 1, 2023
35a4d33
fix bug
chilo-ms Nov 1, 2023
44a7cc5
update
chilo-ms Nov 1, 2023
b2fdb06
update
chilo-ms Nov 2, 2023
eeb6552
refactor
chilo-ms Nov 2, 2023
1b5117d
update
chilo-ms Nov 2, 2023
df7ef46
update
chilo-ms Nov 3, 2023
993b2ad
change function name
chilo-ms Nov 3, 2023
34a86d7
Merge branch 'main' into chi/trt_engine_wrapper
chilo-ms Nov 3, 2023
9631f73
update
chilo-ms Nov 4, 2023
93f9fbb
update
chilo-ms Nov 5, 2023
d5974fc
refactor
chilo-ms Nov 6, 2023
7202b73
check compute capability
chilo-ms Nov 6, 2023
60f6e7e
Merge branch 'main' into chi/trt_engine_wrapper
chilo-ms Nov 6, 2023
5838143
update for reading engine byte data
chilo-ms Nov 6, 2023
aea26b1
add script to generate engine wrapper onnx model
chilo-ms Nov 7, 2023
f4b38f7
refactor
chilo-ms Nov 7, 2023
de9f510
fix bug
chilo-ms Nov 7, 2023
65331ed
fix format
chilo-ms Nov 7, 2023
befba02
refactor
chilo-ms Nov 7, 2023
5103dda
add unit test
chilo-ms Nov 7, 2023
f933379
refactor script
chilo-ms Nov 7, 2023
11fd212
fix format
chilo-ms Nov 7, 2023
561b059
update
chilo-ms Nov 9, 2023
789efe8
fix format
chilo-ms Nov 9, 2023
a5843c2
fix format
chilo-ms Nov 9, 2023
11ce3fc
fix gen_trt_engine_wrapper_onnx_model.py
chilo-ms Nov 10, 2023
f9206f3
refactor unit test
chilo-ms Nov 10, 2023
0bd5b8c
update
chilo-ms Nov 10, 2023
e2c3f16
Merge branch 'main' into chi/trt_engine_wrapper
chilo-ms Nov 19, 2023
6605fd4
update script
chilo-ms Nov 19, 2023
699d538
fix format
chilo-ms Nov 19, 2023
1cbe3e9
fix bug for conflict resolve
chilo-ms Nov 20, 2023
f8e775d
refactor script
chilo-ms Nov 20, 2023
8f7c7ac
generate ep context node model from TRT EP
chilo-ms Nov 23, 2023
77a62f2
add trt_dump_ep_context_model, trt_ep_context_embed_mode, trt_ep_cont…
chilo-ms Nov 23, 2023
04baad7
fix format
chilo-ms Nov 24, 2023
c3a028b
swap the position of CreateNodeComputeFromGraph and CreateNodeCompute…
chilo-ms Jan 8, 2024
db46b64
merge PR 18879 and PR 18834
chilo-ms Jan 9, 2024
842cdf0
merge PR 18879 and PR 18834 (continue)
chilo-ms Jan 9, 2024
ca8d49f
Merge branch 'main' into chi/trt_engine_wrapper
chilo-ms Jan 9, 2024
e0d3346
remove kernelcontext_setoutput
chilo-ms Jan 9, 2024
f9231a5
fix bugs after merge main
chilo-ms Jan 9, 2024
cebfcd8
apply lintrunner -a
chilo-ms Jan 10, 2024
b1c4305
Use 'hardware_architecture'
chilo-ms Jan 11, 2024
0435971
merge main and also enforce get compute capability once inside ep con…
chilo-ms Jan 11, 2024
77bf077
Merge branch 'main' into chi/trt_engine_wrapper
chilo-ms Jan 11, 2024
28bdd0a
remove unnecessary code
chilo-ms Jan 11, 2024
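
Several commits above introduce new TensorRT provider options (trt_dump_ep_context_model, trt_ep_context_embed_mode) for dumping a model that carries an EPContext node. Below is a minimal sketch of how these options might be set through the C API; the option keys are taken from the commit messages above, and the surrounding session setup is assumed rather than shown:

// Hedged sketch: enable engine caching and (assumption) EPContext model dumping.
// Error handling omitted for brevity; requires onnxruntime_c_api.h.
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
OrtTensorRTProviderOptionsV2* trt_options = nullptr;
api->CreateTensorRTProviderOptions(&trt_options);
const char* keys[] = {"trt_engine_cache_enable", "trt_dump_ep_context_model", "trt_ep_context_embed_mode"};
const char* values[] = {"1", "1", "0"};  // embed_mode 0: engine kept as a separate cache file
api->UpdateTensorRTProviderOptions(trt_options, keys, values, 3);
// ... append to session options and create the session as usual ...
api->ReleaseTensorRTProviderOptions(trt_options);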
8 changes: 4 additions & 4 deletions include/onnxruntime/core/framework/op_kernel_context.h
@@ -186,6 +186,10 @@ class OpKernelContext {
*/
AllocatorPtr GetAllocator(const OrtDevice& device) const;

#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

protected:
OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream);

@@ -195,10 +199,6 @@
const OrtValue* GetImplicitInputMLValue(int index) const;
OrtValue* GetOutputMLValue(int index);

#ifdef ENABLE_ATEN
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif

// Creates the OrtValue* based on the shape, if it does not exist
virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);

9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4512,6 +4512,15 @@ struct OrtApi {
* \since Version 1.17.
*/
ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);

/** \brief Used for custom operators to set an output of a kernel
*
* \see ::OrtCustomOp
*
* \since Version 1.17.
*/
ORT_API2_STATUS(KernelContext_SetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
_In_ const OrtValue* ort_value);
};

/*
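A minimal sketch of how a custom operator's compute callback might call the new KernelContext_SetOutput entry point to bind an existing OrtValue as an output without copying; the helper name and the cached value are hypothetical:

// Assumption: "cached_value" is an OrtValue the kernel already owns (e.g., a
// preallocated tensor); the new API binds it directly as output 0.
OrtStatusPtr BindCachedOutput(const OrtApi* api, OrtKernelContext* context, const OrtValue* cached_value) {
  return api->KernelContext_SetOutput(context, /*index=*/0, cached_value);
}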
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2052,6 +2052,7 @@ struct KernelContext {
ConstValue GetInput(size_t index) const;
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
void SetOutput(size_t index, const OrtValue& ort_value);
void* GetGPUComputeStream() const;
Logger GetLogger() const;
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1634,6 +1634,10 @@ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
return UnownedValue(out);
}

inline void KernelContext::SetOutput(size_t index, const OrtValue& ort_value) {
Ort::ThrowOnError(GetApi().KernelContext_SetOutput(ctx_, index, &ort_value));
}

inline void* KernelContext::GetGPUComputeStream() const {
void* out = nullptr;
Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
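And the equivalent through the C++ wrapper added above — a short hedged sketch, assuming a kernel that forwards a value it already holds:

// Hedged sketch: forward an already-materialized OrtValue as output 0 from a
// custom-op kernel; "engine_result" is hypothetical. Requires onnxruntime_cxx_api.h.
void ForwardCachedOutput(OrtKernelContext* raw_ctx, const OrtValue& engine_result) {
  Ort::KernelContext ctx(raw_ctx);
  ctx.SetOutput(0, engine_result);  // wraps the new KernelContext_SetOutput C API; throws on error
}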
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/execution_frame.cc
@@ -50,7 +50,7 @@ IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map,

IExecutionFrame::~IExecutionFrame() = default;

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) {
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/execution_frame.h
@@ -54,7 +54,7 @@ class IExecutionFrame {
const OrtValue* GetNodeInputOrOutputMLValue(int index) const;
OrtValue* GetMutableNodeInputOrOutputMLValue(int index);

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
// Override the index-th output with ort_value
Status SetOutputMLValue(int index, const OrtValue& ort_value);
#endif
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/op_kernel.cc
@@ -186,7 +186,7 @@ AllocatorPtr OpKernelContext::GetAllocator(const OrtDevice& device) const {
return execution_frame_->GetAllocator(device);
}

#ifdef ENABLE_ATEN
#if defined(ENABLE_ATEN) || defined(USE_TENSORRT)
Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) {
if (index < 0 || index >= OutputCount()) {
return Status(common::ONNXRUNTIME, common::FAIL,
104 changes: 104 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -0,0 +1,104 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>
#include <fstream>
#include <filesystem>

#include "onnx_ctx_model_helper.h"

namespace onnxruntime {

/*
 * Check whether the graph has the EP context contrib op.
 * The node can carry the precompiled engine info, which lets the TRT EP load the engine directly.
 *
 * Note: See contrib_defs.cc for more details about the "EPContext" contrib op.
 */
bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
return true;
}
}
return false;
}

const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
while (cur_graph->IsSubgraph()) {
cur_graph = cur_graph->ParentGraph();
}

const Graph& main_graph = *cur_graph;
return main_graph.ModelPath();
}

/*
 * Resolve the engine cache path relative to the directory of the ONNX model,
 * e.g. a model at "/models/net.onnx" with engine_cache_path "net.engine"
 * resolves to "/models/net.engine".
 */
std::filesystem::path LocateEngineRelativeToPath(const std::string& engine_cache_path, const onnxruntime::Path& path) {
  std::filesystem::path base_path(path.ToPathString());
  std::filesystem::path parent_path = base_path.parent_path();
  std::filesystem::path engine_path = parent_path.append(engine_cache_path);
  return engine_path;
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
if (embed_mode) {
  // Engine binary is embedded in the "ep_cache_context" attribute
  const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
  *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(context_binary.data(), context_binary.size()));
  if (!(*trt_engine_)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                           "TensorRT EP could not deserialize engine from binary data");
  }
  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized engine from binary data";
} else {
  // Engine binary lives in a separate cache file pointed to by "ep_cache_context"
  std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in);
  engine_file.seekg(0, std::ios::end);
  size_t engine_size = engine_file.tellg();
  engine_file.seekg(0, std::ios::beg);
  std::unique_ptr<char[]> engine_buf{new char[engine_size]};
  engine_file.read(engine_buf.get(), engine_size);
  *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
  if (!(*trt_engine_)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                           "TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string());
  }
  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized " + engine_cache_path_.string();
}
return Status::OK();
}

/*
 * Sanity check for the EP context contrib op.
 */
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

// "embed_mode" attr and "ep_cache_context" attr should be present
if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) {
// ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0
const int64_t embed_mode = attrs.at(EMBED_MODE).i();

// engine cache path
if (embed_mode == 0) {
// First assume engine cache path is relatvie to model path,
// If not, then assume the engine cache path is an absolute path.
engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer));
auto default_engine_cache_path_ = engine_cache_path_;
if (!std::filesystem::exists(engine_cache_path_)) {
engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s());
if (!std::filesystem::exists(engine_cache_path_)) {
LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine";
return false;
}
}
}
}
return true;
}
} // namespace onnxruntime
42 changes: 42 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <filesystem>

#include "NvInfer.h"
#include "core/providers/shared_library/provider_api.h"

namespace onnxruntime {

static const std::string EPCONTEXT_OP = "EPContext";
static const std::string MAIN_CONTEXT = "main_context";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path LocateEngineRelativeToPath(const std::string& engine_cache_path, const onnxruntime::Path& path);

class TensorRTCacheModelHandler {
public:
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
nvinfer1::IRuntime* trt_runtime) : trt_engine_(trt_engine), trt_runtime_(trt_runtime) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
nvinfer1::IRuntime* trt_runtime_;
std::filesystem::path engine_cache_path_;
};  // TensorRTCacheModelHandler
} // namespace onnxruntime
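
For context, a minimal sketch of how the TRT EP is expected to drive this handler when it encounters a fused graph consisting of an EPContext node; the surrounding provider state (the runtime pointer and the engine slot) and the helper's name are assumptions, not code from this PR:

// Hedged usage sketch: "trt_runtime" is the EP's nvinfer1::IRuntime*, and
// graph_viewer is the subgraph being compiled. GraphHasCtxNode(graph_viewer)
// is checked by the caller before taking this path.
Status LoadEngineFromEpContext(const GraphViewer& graph_viewer, nvinfer1::IRuntime* trt_runtime,
                               std::unique_ptr<nvinfer1::ICudaEngine>& trt_engine) {
  TensorRTCacheModelHandler handler(&trt_engine, trt_runtime);
  if (!handler.ValidateEPCtxNode(graph_viewer)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Invalid EPContext node");
  }
  return handler.GetEpContextFromGraph(graph_viewer);  // fills trt_engine on success
}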