From 5bec52203d980d5925a8f806c7e7c922cf694636 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com>
Date: Sat, 20 Jul 2024 06:11:04 +0200
Subject: [PATCH] [TensorRT] Enable refitting an embedded engine when provided
 as byte stream (#21357)

### Description
This allows refitting an engine using an ONNX model that is not available on disk.
This is important for ONNX files that are stored encrypted on disk and are only decrypted in memory.

---
 .../tensorrt/tensorrt_provider_options.h       |  7 ++
 .../tensorrt/onnx_ctx_model_helper.cc          | 24 +++++-
 .../tensorrt/onnx_ctx_model_helper.h           |  6 ++
 .../tensorrt/tensorrt_execution_provider.cc    | 81 +++++++++++++-----
 .../tensorrt/tensorrt_execution_provider.h     |  7 +-
 .../tensorrt_execution_provider_info.cc        | 19 +++++
 .../tensorrt_execution_provider_info.h         |  2 +
 .../tensorrt/tensorrt_provider_factory.cc      |  2 +
 .../core/session/provider_bridge_ort.cc        |  4 +
 .../test/perftest/command_args_parser.cc       |  6 +-
 onnxruntime/test/perftest/ort_test_session.cc  | 18 +++-
 .../test/perftest/test_configuration.h         |  1 +
 .../providers/tensorrt/tensorrt_basic_test.cc  | 83 +++++++++++++++++--
 13 files changed, 225 insertions(+), 35 deletions(-)

diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index d008058821be3..816eaaf9bc71a 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -16,6 +16,7 @@ struct OrtTensorRTProviderOptionsV2 {
   int device_id{0};                                      // cuda device id.
   int has_user_compute_stream{0};                        // indicator of user specified CUDA compute stream.
   void* user_compute_stream{nullptr};                    // user specified CUDA compute stream.
+                                                         // can be updated using: UpdateTensorRTProviderOptionsWithValue
   int trt_max_partition_iterations{1000};                // maximum iterations for TensorRT parser to get capability
   int trt_min_subgraph_size{1};                          // minimum size of TensorRT subgraphs
   size_t trt_max_workspace_size{1 << 30};                // maximum workspace size for TensorRT.
@@ -78,6 +79,12 @@ struct OrtTensorRTProviderOptionsV2 {
   const char* trt_onnx_model_folder_path{nullptr};  // Folder path relative to the current working directory for
                                                     // the ONNX model containing the weights (applicable only when
                                                     // the "trt_weight_stripped_engine_enable" option is enabled)
+  const void* trt_onnx_bytestream{nullptr};         // The byte stream of the original ONNX model containing the weights
+                                                    // (applicable only when the "trt_weight_stripped_engine_enable"
+                                                    // option is enabled)
+                                                    // can be updated using: UpdateTensorRTProviderOptionsWithValue
+  size_t trt_onnx_bytestream_size{0};               // size of the byte stream provided as "trt_onnx_bytestream"
+                                                    // can be updated using: UpdateTensorRTProviderOptionsWithValue
   const char* trt_engine_cache_prefix{nullptr};     // specify engine cache prefix
   int trt_engine_hw_compatible{0};                  // Enable hardware compatibility. Default 0 = false, nonzero = true
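Since the two new options hold a raw pointer and a size rather than strings, they are set through `UpdateTensorRTProviderOptionsWithValue`, as the header comments above note. A minimal sketch of that wiring from application code, assuming the decrypted ONNX model already sits in memory (the helper and variable names are illustrative, not part of this patch):

```cpp
// Sketch only: hand a decrypted, in-memory ONNX model to the TensorRT EP so a
// weight-stripped engine can be refitted without the weights ever touching disk.
#include <onnxruntime_cxx_api.h>

void AppendTrtWithInMemoryOnnx(Ort::SessionOptions& session_options,
                               void* onnx_bytes, size_t onnx_size) {
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  // String-typed options still go through UpdateTensorRTProviderOptions.
  const char* keys[] = {"trt_weight_stripped_engine_enable"};
  const char* values[] = {"1"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 1));

  // Pointer/size-typed options use UpdateTensorRTProviderOptionsWithValue; the size is
  // passed by address and dereferenced as size_t (see provider_bridge_ort.cc below).
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream", onnx_bytes));
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream_size", &onnx_size));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options));
  api.ReleaseTensorRTProviderOptions(trt_options);
}
```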
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
index 42788f2960197..ef45d6c85d6a9 100644
--- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -274,6 +274,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
   auto& attrs = node->GetAttributes();
   const int64_t embed_mode = attrs.at(EMBED_MODE).i();
 
+  // Only make path checks if model not provided as byte buffer
+  bool make_secure_path_checks = !GetModelPath(graph_viewer).empty();
+
   if (embed_mode) {
     // Get engine from byte stream.
     const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
@@ -284,6 +287,23 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "TensorRT EP could not deserialize engine from binary data");
     }
+
+    if (weight_stripped_engine_refit_) {
+      const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
+      std::string placeholder;
+      auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
+                                                           onnx_model_folder_path_,
+                                                           placeholder,
+                                                           make_secure_path_checks,
+                                                           onnx_model_bytestream_,
+                                                           onnx_model_bytestream_size_,
+                                                           (*trt_engine_).get(),
+                                                           false /* serialize refitted engine to disk */,
+                                                           detailed_build_log_);
+      if (status != Status::OK()) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+      }
+    }
   } else {
     // Get engine from cache file.
     std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();
@@ -343,7 +363,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
     auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
                                                          onnx_model_folder_path_,
                                                          weight_stripped_engine_cache,
-                                                         true /* path check for security */,
+                                                         make_secure_path_checks,
+                                                         onnx_model_bytestream_,
+                                                         onnx_model_bytestream_size_,
                                                          (*trt_engine_).get(),
                                                          true /* serialize refitted engine to disk */,
                                                          detailed_build_log_);
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
index 3be08d043da48..3af0143cbf14e 100644
--- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -52,6 +52,8 @@ class TensorRTCacheModelHandler {
                             std::string compute_capability,
                             bool weight_stripped_engine_refit,
                             std::string onnx_model_folder_path,
+                            const void* onnx_model_bytestream,
+                            size_t onnx_model_bytestream_size,
                             bool detailed_build_log)
       : trt_engine_(trt_engine),
         trt_runtime_(trt_runtime),
@@ -59,6 +61,8 @@ class TensorRTCacheModelHandler {
         compute_capability_(compute_capability),
         weight_stripped_engine_refit_(weight_stripped_engine_refit),
         onnx_model_folder_path_(onnx_model_folder_path),
+        onnx_model_bytestream_(onnx_model_bytestream),
+        onnx_model_bytestream_size_(onnx_model_bytestream_size),
         detailed_build_log_(detailed_build_log) {
   }
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);
@@ -74,6 +78,8 @@ class TensorRTCacheModelHandler {
   std::string compute_capability_;
   bool weight_stripped_engine_refit_;
   std::string onnx_model_folder_path_;
+  const void* onnx_model_bytestream_;
+  size_t onnx_model_bytestream_size_;
   bool detailed_build_log_;
 };  // TRTCacheModelHandler
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 67cbc8f5d6f13..cdbb7bb2a8094 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1333,6 +1333,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
     engine_cache_enable_ = info.engine_cache_enable;
     weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
     onnx_model_folder_path_ = info.onnx_model_folder_path;
+    onnx_model_bytestream_ = info.onnx_bytestream;
+    onnx_model_bytestream_size_ = info.onnx_bytestream_size;
+    if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) ||
+        (onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) {
+      ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                         "When providing either 'trt_onnx_bytestream_size' or "
+                                         "'trt_onnx_bytestream', both have to be provided"));
+    }
     timing_cache_enable_ = info.timing_cache_enable;
     force_timing_cache_match_ = info.force_timing_cache;
     detailed_build_log_ = info.detailed_build_log;
@@ -1757,7 +1765,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
                         << ", trt_ep_context_file_path: " << ep_context_file_path_
                         << ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
                         << ", trt_cache_prefix: " << cache_prefix_
-                        << ", trt_engine_hw_compatible: " << engine_hw_compatible_;
+                        << ", trt_engine_hw_compatible: " << engine_hw_compatible_
+                        << ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
 }
 
 TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -2597,28 +2606,42 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
                                                       std::string& onnx_model_folder_path,
                                                       std::string& weight_stripped_engine_cath_path,
                                                       bool path_check,
+                                                      const void* onnx_model_bytestream,
+                                                      size_t onnx_model_bytestream_size,
                                                       nvinfer1::ICudaEngine* trt_engine,
                                                       bool serialize_refitted_engine,
                                                       bool detailed_build_log) {
 #if NV_TENSORRT_MAJOR >= 10
+  bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0;
   std::filesystem::path onnx_model_path{onnx_model_folder_path};
-  onnx_model_path.append(onnx_model_filename);
-  if (path_check && IsAbsolutePath(onnx_model_path.string())) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                           "For security purpose, the ONNX model path should be set with "
-                           "a relative path, but it is an absolute path: " +
-                               onnx_model_path.string());
-  }
-  if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                           "The ONNX model path has '..'. For security purpose, it's not "
-                           "allowed to point outside the directory.");
-  }
+  if (refit_from_file) {
+    if (!onnx_model_filename.empty()) {
+      onnx_model_path.append(onnx_model_filename);
+    }
+    if (onnx_model_path.empty()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                             "The ONNX model was not provided as a path. "
+                             "Please provide an ONNX bytestream to enable refitting the weightless engine.");
+    } else {
+      // check if file path to ONNX is legal
+      if (path_check && IsAbsolutePath(onnx_model_path.string())) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                               "For security purpose, the ONNX model path should be set with "
+                               "a relative path, but it is an absolute path: " +
+                                   onnx_model_path.string());
+      }
+      if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                               "The ONNX model path has '..'. For security purpose, it's not "
+                               "allowed to point outside the directory.");
+      }
 
-  if (!std::filesystem::exists(onnx_model_path)) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                           "The ONNX model " + onnx_model_path.string() +
-                               " does not exist.");
-  }
+      if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                               "The ONNX model " + onnx_model_path.string() +
+                                   " does not exist.");
+      }
+    }
   }
 
   // weight-stripped engine refit logic
@@ -2626,9 +2649,18 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   auto refitter = std::unique_ptr<nvinfer1::IRefitter>(nvinfer1::createInferRefitter(*trt_engine, trt_logger));
   auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
       nvonnxparser::createParserRefitter(*refitter, trt_logger));
-  if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                           "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
+  if (refit_from_file) {
+    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string();
+    if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                             "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
+    }
+  } else {
+    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array";
+    if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                             "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream");
+    }
   }
   if (refitter->refitCudaEngine()) {
     LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine.";
@@ -3212,10 +3244,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
   }
 
   if (weight_stripped_engine_refit_) {
+    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
+    char* onnx = string_buf.data();
+    size_t onnx_size = string_buf.size();
     auto status = RefitEngine(model_path_,
                               onnx_model_folder_path_,
                               engine_cache_path,
                               false /* path check for security */,
+                              onnx,
+                              onnx_size,
                               trt_engine.get(),
                               true /* serialize refitted engine to disk */,
                               detailed_build_log_);
@@ -3685,6 +3722,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
                                   onnx_model_folder_path_,
                                   engine_cache_path,
                                   false /* path check for security */,
+                                  onnx_model_bytestream_,
+                                  onnx_model_bytestream_size_,
                                   trt_engine,
                                   true /* serialize refitted engine to disk */,
                                   detailed_build_log_);
@@ -3910,6 +3949,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
                                                  compute_capability_,
                                                  weight_stripped_engine_enable_,
                                                  onnx_model_folder_path_,
+                                                 onnx_model_bytestream_,
+                                                 onnx_model_bytestream_size_,
                                                  detailed_build_log_);
   auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
   if (status != Status::OK()) {
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index b58e86237860c..3f20314438564 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -274,13 +274,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   bool IsGraphCaptured(int graph_annotation_id) const override;
   Status ReplayGraph(int graph_annotation_id) override;
 
-  /**
-   * Refit the weight-stripped engine
-   */
   static common::Status RefitEngine(std::string onnx_model_filename,
                                     std::string& onnx_model_folder_path,
                                     std::string& weight_stripped_engine_cath_path,
                                     bool path_check,
+                                    const void* onnx_model_bytestream,
+                                    size_t onnx_model_bytestream_size,
                                     nvinfer1::ICudaEngine* trt_engine,
                                     bool serialize_refitted_engine,
                                     bool detailed_build_log);
@@ -305,6 +304,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   bool weight_stripped_engine_enable_ = false;
   bool weight_stripped_engine_refit_ = false;
   std::string onnx_model_folder_path_;
+  const void* onnx_model_bytestream_;
+  size_t onnx_model_bytestream_size_;
   bool build_heuristics_enable_ = false;
   bool sparsity_enable_ = false;
   int builder_optimization_level_ = 3;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
index 9fe39f5921e1c..63b6d35072290 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -54,6 +54,8 @@ constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode";
 constexpr const char* kEpContextFilePath = "trt_ep_context_file_path";
 constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
 constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
+constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
+constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
 }  // namespace provider_option_names
 }  // namespace tensorrt
 
@@ -61,6 +63,7 @@ constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
 TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
   TensorrtExecutionProviderInfo info{};
   void* user_compute_stream = nullptr;
+  void* onnx_bytestream = nullptr;
   ORT_THROW_IF_ERROR(
       ProviderOptionsParser{}
           .AddValueParser(
@@ -122,10 +125,20 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
           .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path)
           .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
           .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible)
+          .AddValueParser(
+              tensorrt::provider_option_names::kONNXBytestream,
+              [&onnx_bytestream](const std::string& value_str) -> Status {
+                size_t address;
+                ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
+                onnx_bytestream = reinterpret_cast<void*>(address);
+                return Status::OK();
+              })
+          .AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
           .Parse(options));  // add new provider option here.
   info.user_compute_stream = user_compute_stream;
   info.has_user_compute_stream = (user_compute_stream != nullptr);
+  info.onnx_bytestream = onnx_bytestream;
   return info;
 }
 
@@ -173,6 +186,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
       {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)},
       {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)},
       {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
+      {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
+      {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
   };
   return options;
 }
@@ -234,6 +249,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
       {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)},
       {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)},
       {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
+      {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
+      {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
   };
   return options;
 }
@@ -336,5 +353,7 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
   trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode;
   trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path);
   trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
+  trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
+  trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
 }
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
index 3b859ea2da466..50b934fd5fcbc 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
@@ -34,6 +34,8 @@ struct TensorrtExecutionProviderInfo {
   std::string engine_cache_path{""};
   bool weight_stripped_engine_enable{false};
   std::string onnx_model_folder_path{""};
+  const void* onnx_bytestream{nullptr};
+  size_t onnx_bytestream_size{0};
   bool engine_decryption_enable{false};
   std::string engine_decryption_lib_path{""};
   bool force_sequential_engine_build{false};
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
index 6430ffab09976..e242788ff389a 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -116,6 +116,8 @@ struct Tensorrt_Provider : Provider {
   info.ep_context_embed_mode = options.trt_ep_context_embed_mode;
   info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix;
"" : options.trt_engine_cache_prefix; info.engine_hw_compatible = options.trt_engine_hw_compatible != 0; + info.onnx_bytestream = options.trt_onnx_bytestream; + info.onnx_bytestream_size = options.trt_onnx_bytestream_size; return std::make_shared(info); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 4f9669a7dcc4c..1d21933e9cba9 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2465,6 +2465,10 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptionsWithValue, if (strcmp(key, "user_compute_stream") == 0) { tensorrt_options->has_user_compute_stream = 1; tensorrt_options->user_compute_stream = value; + } else if (strcmp(key, "trt_onnx_bytestream") == 0) { + tensorrt_options->trt_onnx_bytestream = value; + } else if (strcmp(key, "trt_onnx_bytestream_size") == 0) { + tensorrt_options->trt_onnx_bytestream_size = *reinterpret_cast(value); } return nullptr; #else diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b7c99fa66a1ea..e6d4e0a94abd3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -143,6 +143,7 @@ namespace perftest { "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n" "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" + "\t-l Provide file as binary in memory by using fopen before session creation.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -205,7 +206,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznl"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -390,6 +391,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'n': test_config.run_config.exit_after_session_creation = true; break; + case 'l': + test_config.model_info.load_via_path = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ff782da35cbe6..92d732fba2a0a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -5,6 +5,7 @@ #include "ort_test_session.h" #include #include +#include #include #include #include @@ -816,8 +817,21 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); #endif } - session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); - + if (!performance_test_config.model_info.load_via_path) { + session_ = Ort::Session(env, performance_test_config.model_info.model_file_path.c_str(), session_options); + } else { + std::ifstream file(performance_test_config.model_info.model_file_path.c_str(), + std::ios::binary | std::ios::in | std::ios::ate); + if (file.is_open()) { + const std::streamsize fsize = file.tellg(); + file.seekg(0, std::ios_base::beg); + std::vector model_bytes(narrow(fsize)); + file.read(model_bytes.data(), fsize); + session_ = Ort::Session(env, model_bytes.data(), model_bytes.size(), session_options); + } else { + ORT_THROW("Model file could not be opened.\n"); + } + } size_t output_count = session_.GetOutputCount(); output_names_.resize(output_count); Ort::AllocatorWithDefaultOptions a; diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 70a6b12690d5d..209fb55fe93d4 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -29,6 +29,7 @@ struct ModelInfo { std::basic_string model_file_path; std::basic_string input_file_path; std::basic_string result_file_path; + bool load_via_path = false; }; struct MachineConfig { diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 2b5b82d0fc16a..63327a028c6f4 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -122,6 +122,18 @@ void CreateBaseModel(const PathString& model_name, status = onnxruntime::Model::Save(model, model_name); } +std::vector ReadFileFromDisk(const PathString& path) { + std::fstream file(path.c_str(), std::fstream::binary | std::fstream::in | std::fstream::ate); + std::vector file_bytes; + if (file.is_open()) { + auto fsize = file.tellg(); + file.seekg(0, std::ios_base::beg); + file_bytes.resize(fsize); + file.read(file_bytes.data(), fsize); + } + return file_bytes; +} + bool HasCacheFileWithPrefix(const std::string& prefix, std::string file_dir = "") { std::filesystem::path target_dir; if (file_dir.empty()) { @@ -360,7 +372,8 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { } TEST(TensorrtExecutionProviderTest, EPContextNode) { - PathString model_name = ORT_TSTR("EPContextNode_test.onnx"); + std::string model_name_str = "EPContextNode_test.onnx"; + PathString model_name = ToPathString(model_name_str); std::string graph_name = "EPContextNode_test"; std::string sess_log_id = "EPContextNode_test"; std::vector dims = {1, 3, 2}; @@ -461,11 +474,11 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { */ InferenceSession session_object3{so, GetEnvironment()}; OrtTensorRTProviderOptionsV2 params3; - model_name = ToPathString(params.trt_ep_context_file_path); + PathString ctx_model_name = ToPathString(params.trt_ep_context_file_path); params3.trt_engine_cache_enable = 1; execution_provider = TensorrtExecutionProviderWithOptions(¶ms3); EXPECT_TRUE(session_object3.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - status = session_object3.Load(model_name); + status = session_object3.Load(ctx_model_name); ASSERT_TRUE(status.IsOK()); status = session_object3.Initialize(); ASSERT_TRUE(status.IsOK()); @@ -490,10 +503,10 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { */ InferenceSession 
   InferenceSession session_object4{so, GetEnvironment()};
   OrtTensorRTProviderOptionsV2 params4;
-  model_name = ORT_TSTR("./context_model_folder/EPContextNode_test_ctx.onnx");
+  ctx_model_name = ToPathString("./context_model_folder/EPContextNode_test_ctx.onnx");
   execution_provider = TensorrtExecutionProviderWithOptions(&params4);
   EXPECT_TRUE(session_object4.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
-  status = session_object4.Load(model_name);
+  status = session_object4.Load(ctx_model_name);
   ASSERT_TRUE(status.IsOK());
   status = session_object4.Initialize();
   ASSERT_TRUE(status.IsOK());
@@ -514,7 +527,6 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
   params5.trt_dump_ep_context_model = 1;
   params5.trt_ep_context_embed_mode = 1;
   params5.trt_ep_context_file_path = "EP_Context_model_2.onnx";
-  model_name = ORT_TSTR("EPContextNode_test.onnx");
   execution_provider = TensorrtExecutionProviderWithOptions(&params5);
   EXPECT_TRUE(session_object5.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
   status = session_object5.Load(model_name);
@@ -528,10 +540,10 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
   InferenceSession session_object6{so, GetEnvironment()};
   OrtTensorRTProviderOptionsV2 params6;
   params6.trt_ep_context_embed_mode = 1;
-  model_name = ToPathString(params5.trt_ep_context_file_path);
+  ctx_model_name = ToPathString(params5.trt_ep_context_file_path);
   execution_provider = TensorrtExecutionProviderWithOptions(&params6);
   EXPECT_TRUE(session_object6.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
-  status = session_object6.Load(model_name);
+  status = session_object6.Load(ctx_model_name);
   ASSERT_TRUE(status.IsOK());
   status = session_object6.Initialize();
   ASSERT_TRUE(status.IsOK());
@@ -543,6 +555,61 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
   // Y: 1, 3, 3, 2, 2, 2
   // Z: 1, 3, 3, 2, 2, 2
   RunSession(session_object6, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
+
+  /*
+   * Test case 7: Run context model with ONNX in memory
+   */
+  auto model_bytes = ReadFileFromDisk(model_name);
+  std::string ctx_model_name_str = "EP_Context_model_weight_stripped.onnx";
+  ctx_model_name = ToPathString(ctx_model_name_str);
+  InferenceSession session_object7{so, GetEnvironment()};
+  OrtTensorRTProviderOptionsV2 params7;
+  params7.trt_dump_ep_context_model = 1;
+  params7.trt_ep_context_embed_mode = 1;
+  params7.trt_weight_stripped_engine_enable = 1;
+  params7.trt_ep_context_file_path = ctx_model_name_str.c_str();
+  execution_provider = TensorrtExecutionProviderWithOptions(&params7);
+  EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+  status = session_object7.Load(model_bytes.data(), static_cast<int>(model_bytes.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_object7.Initialize();
+  std::cerr << status.ErrorMessage();
+  ASSERT_TRUE(status.IsOK());
+  RunSession(session_object7, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
+
+  /*
+   * Test case 8: Refit weightless context model with ONNX in memory
+   */
+  auto ctx_model_bytes = ReadFileFromDisk(ctx_model_name);
+  InferenceSession session_object8{so, GetEnvironment()};
+  OrtTensorRTProviderOptionsV2 params8;
+  params8.trt_weight_stripped_engine_enable = 1;
+  params8.trt_onnx_bytestream = model_bytes.data();
+  params8.trt_onnx_bytestream_size = model_bytes.size();
+  execution_provider = TensorrtExecutionProviderWithOptions(&params8);
+  EXPECT_TRUE(session_object8.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+  status = session_object8.Load(ctx_model_bytes.data(), static_cast<int>(ctx_model_bytes.size()));
+  std::cerr << status.ErrorMessage();
+  ASSERT_TRUE(status.IsOK());
+  status = session_object8.Initialize();
+  std::cerr << status.ErrorMessage();
+  ASSERT_TRUE(status.IsOK());
+  RunSession(session_object8, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
+
+  /*
+   * Test case 9: Refit weightless context model with ONNX from disk
+   */
+  InferenceSession session_object9{so, GetEnvironment()};
+  OrtTensorRTProviderOptionsV2 params9;
+  params9.trt_weight_stripped_engine_enable = 1;
+  params9.trt_onnx_model_folder_path = model_name_str.c_str();
+  execution_provider = TensorrtExecutionProviderWithOptions(&params9);
+  EXPECT_TRUE(session_object9.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+  status = session_object9.Load(ctx_model_bytes.data(), static_cast<int>(ctx_model_bytes.size()));
+  ASSERT_TRUE(status.IsOK());
+  status = session_object9.Initialize();
+  ASSERT_TRUE(status.IsOK());
+  RunSession(session_object9, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
 }
 
 TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
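The new test cases exercise the same flow an application would use: generate a weight-stripped EP-context model once, then refit it at session-creation time from an in-memory ONNX byte stream. A rough standalone sketch of that flow follows; the file names, the read helper, and the error handling are illustrative assumptions rather than part of this patch:

```cpp
// Sketch: create a session from an EP-context model with an embedded weight-stripped
// TensorRT engine and refit it from an ONNX byte stream, entirely from memory.
#include <fstream>
#include <vector>
#include <onnxruntime_cxx_api.h>
#include "core/providers/tensorrt/tensorrt_provider_options.h"  // path as in this repo; packaged installs may differ

static std::vector<char> ReadAllBytes(const char* path) {
  // Minimal helper; real code should check that the stream opened successfully.
  std::ifstream file(path, std::ios::binary | std::ios::ate);
  std::vector<char> bytes(static_cast<size_t>(file.tellg()));
  file.seekg(0, std::ios::beg);
  file.read(bytes.data(), static_cast<std::streamsize>(bytes.size()));
  return bytes;
}

int main() {
  Ort::Env env;
  // In a real deployment the weight-full ONNX would be decrypted into memory here
  // instead of being read from disk.
  std::vector<char> onnx_bytes = ReadAllBytes("model.onnx");      // weight-full ONNX model
  std::vector<char> ctx_bytes = ReadAllBytes("model_ctx.onnx");   // EP-context model produced earlier
                                                                  // with trt_dump_ep_context_model

  OrtTensorRTProviderOptionsV2 trt_options{};
  trt_options.trt_weight_stripped_engine_enable = 1;
  trt_options.trt_onnx_bytestream = onnx_bytes.data();
  trt_options.trt_onnx_bytestream_size = onnx_bytes.size();

  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_TensorRT_V2(trt_options);

  // The EP-context model itself is also loaded from memory, so nothing is read
  // from disk at session-creation time.
  Ort::Session session(env, ctx_bytes.data(), ctx_bytes.size(), session_options);
  return 0;
}
```

When only the engine is embedded but the weights are still available as a file, `trt_onnx_model_folder_path` can be used instead of the byte-stream options, which is what test case 9 covers.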