From 1274f06c03310643d9b5654c7322709ad8d1b87d Mon Sep 17 00:00:00 2001
From: Chi Lo
Date: Mon, 22 May 2023 17:02:09 +0000
Subject: [PATCH 01/15] update

---
 .../tensorrt/tensorrt_provider_options.h      |  1 +
 .../tensorrt/tensorrt_execution_provider.cc   | 83 ++++++++++++++++++-
 .../tensorrt/tensorrt_execution_provider.h    | 26 ++++++
 .../tensorrt_execution_provider_info.cc       |  4 +
 .../tensorrt_execution_provider_info.h        |  1 +
 .../tensorrt/tensorrt_provider_factory.cc     |  3 +
 .../core/session/provider_bridge_ort.cc       |  2 +
 .../python/onnxruntime_pybind_state.cc        | 10 ++-
 onnxruntime/test/perftest/ort_test_session.cc | 12 ++-
 onnxruntime/test/providers/cpu/model_tests.cc |  2 +-
 .../providers/tensorrt/tensorrt_basic_test.cc | 12 ++-
 11 files changed, 148 insertions(+), 8 deletions(-)

diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 600e255bcdf9f..e7d0f9f03ade9 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -44,4 +44,5 @@ struct OrtTensorRTProviderOptionsV2 {
   const char* trt_profile_min_shapes;  // Specify the range of the input shapes to build the engine with
   const char* trt_profile_max_shapes;  // Specify the range of the input shapes to build the engine with
   const char* trt_profile_opt_shapes;  // Specify the range of the input shapes to build the engine with
+  int trt_cuda_graph_enable;           // Enable CUDA graph in ORT TRT
 };
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index c70099fbf419d..27cf7bdce20f9 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -703,6 +703,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
     profile_min_shapes = info.profile_min_shapes;
     profile_max_shapes = info.profile_max_shapes;
     profile_opt_shapes = info.profile_opt_shapes;
+    cuda_graph_enable_ = info.cuda_graph_enable;
   } else {
     try {
       const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -842,6 +843,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
       profile_min_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMinShapes);
       profile_max_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesMaxShapes);
       profile_opt_shapes = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kProfilesOptShapes);
+
+      const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCudaGraphEnable);
+      if (!cuda_graph_enable_env.empty()) {
+        cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
+      }
     } catch (const std::invalid_argument& ex) {
       LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
     } catch (const std::out_of_range& ex) {
@@ -895,6 +901,13 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
     int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty();
   }

+#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000
+  if (cuda_graph_enable_) {
+    cuda_graph_ = std::make_unique<CUDAGraph>();
+    cuda_graph_->SetStream(stream_);
+  }
+#endif
+
   /*
    * Parse explicit min/max/opt profile shapes from provider options.
* @@ -968,7 +981,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_tactic_sources: " << tactic_sources_ << ", trt_profile_min_shapes: " << profile_min_shapes << ", trt_profile_max_shapes: " << profile_max_shapes - << ", trt_profile_opt_shapes: " << profile_opt_shapes; + << ", trt_profile_opt_shapes: " << profile_opt_shapes + << ", trt_cuda_graph_enable: " << cuda_graph_enable_; } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -982,6 +996,45 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); } +#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 +bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { + return cuda_graph_enable_; +} + +bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void TensorrtExecutionProvider::CaptureBegin() { + cuda_graph_->Reset(); + cuda_graph_->CaptureBegin(); +} + +void TensorrtExecutionProvider::CaptureEnd() { + cuda_graph_->CaptureEnd(); + is_graph_captured_ = true; +} + +bool TensorrtExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status TensorrtExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + // Please note that CUDAGraph::Replay() is not thread safe. + // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), + // therefore calling CUDAGraph::Replay() is guaranteed to be thread safe. + return cuda_graph_->Replay(); +} + +void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), + // therefore following increment is guaranteed to be thread safe. + ++regular_run_count_before_graph_capture_; +} +#endif + AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const { if (mem_type == OrtMemTypeDefault) { return allocator_; @@ -1063,7 +1116,17 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(); } +Status TensorrtExecutionProvider::OnRunStart() { + std::cout << "OnRunStart() ..." << std::endl; + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + CaptureBegin(); + } + return Status::OK(); +} + Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { + std::cout << "OnRunEnd() ..." 
<< std::endl; if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } @@ -2178,6 +2241,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); @@ -2829,6 +2894,22 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector= 10000 + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; +#endif + private: TensorrtExecutionProviderInfo info_; bool external_stream_ = false; @@ -209,6 +220,14 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool timing_cache_enable_ = false; bool force_timing_cache_match_ = false; bool detailed_build_log_ = false; + bool cuda_graph_enable_ = false; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 + std::unique_ptr cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. +#endif std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; std::unordered_map> parsers_; @@ -259,5 +278,12 @@ class TensorrtExecutionProvider : public IExecutionProvider { /**Check whether all the nodes of subgraph are supported*/ bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 + bool IsGraphCaptureAllowed() const; + void CaptureBegin(); + void CaptureEnd(); + void IncrementRegularRunCountBeforeGraphCapture(); +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 44af3a236500f..de415c7332390 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -43,6 +43,7 @@ constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; +constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; } // namespace provider_option_names } // namespace tensorrt @@ -91,6 +92,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes) .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) .Parse(options)); // add new provider option here. 
return info; @@ -129,6 +131,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kProfilesMinShapes, MakeStringWithClassicLocale(info.profile_min_shapes)}, {tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, + {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, }; return options; } @@ -167,6 +170,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kBuilderOptimizationLevel, MakeStringWithClassicLocale(info.trt_builder_optimization_level)}, {tensorrt::provider_option_names::kAuxiliaryStreams, MakeStringWithClassicLocale(info.trt_auxiliary_streams)}, {tensorrt::provider_option_names::kTacticSources, kTacticSources_}, + {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, }; return options; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index 7235bb6940f9c..4fb9837e1c040 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -49,6 +49,7 @@ struct TensorrtExecutionProviderInfo { std::string profile_min_shapes{""}; std::string profile_max_shapes{""}; std::string profile_opt_shapes{""}; + bool cuda_graph_enable{false}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 5a1b662078e90..ed5cca93f74d4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -94,6 +94,7 @@ struct Tensorrt_Provider : Provider { info.profile_min_shapes = options.trt_profile_min_shapes == nullptr ? "" : options.trt_profile_min_shapes; info.profile_max_shapes = options.trt_profile_max_shapes == nullptr ? "" : options.trt_profile_max_shapes; info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? 
"" : options.trt_profile_opt_shapes; + info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; common::Status status = CreateTensorRTCustomOpDomainList(info); if (!status.IsOK()) { @@ -229,6 +230,8 @@ struct Tensorrt_Provider : Provider { dest[str_size] = '\0'; trt_options.trt_profile_opt_shapes = (const char*)dest; } + + trt_options.trt_cuda_graph_enable = internal_options.cuda_graph_enable; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 94d4f4daa6abf..4fd39d1ef7422 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1312,6 +1312,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_profile_min_shapes = ""; trt_options_converted.trt_profile_max_shapes = ""; trt_options_converted.trt_profile_opt_shapes = ""; + trt_options_converted.trt_cuda_graph_enable = 0; return trt_options_converted; } @@ -1668,6 +1669,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT (*out)->trt_profile_min_shapes = nullptr; (*out)->trt_profile_max_shapes = nullptr; (*out)->trt_profile_opt_shapes = nullptr; + (*out)->trt_cuda_graph_enable = false; return nullptr; #else ORT_UNUSED_PARAMETER(out); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index cf709e508dbf3..ffa810475ebaa 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -380,7 +380,7 @@ std::unique_ptr CreateExecutionProviderInstance( nullptr, nullptr, nullptr, - nullptr}; + 0}; for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { @@ -598,6 +598,14 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a string of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'.\n"); } + } else if (option.first == "trt_cuda_graph_enable") { + if (option.second == "True" || option.second == "true") { + params.trt_cuda_graph_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.trt_cuda_graph_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. 
Default value is 'False'.\n"); + } } else { ORT_THROW("Invalid TensorRT EP option: ", option.first); } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ae3f4ca05aadb..53b1799ea6d78 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -133,6 +133,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string trt_profile_min_shapes = ""; std::string trt_profile_max_shapes = ""; std::string trt_profile_opt_shapes = ""; + bool trt_cuda_graph_enable = false; #ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); @@ -362,8 +363,16 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_profile_opt_shapes' should be a non-empty string.\n"); } + } else if (key == "trt_cuda_graph_enable") { + if (value == "true" || value == "True") { + trt_cuda_graph_enable = true; + } else if (value == "false" || value == "False") { + trt_cuda_graph_enable = false; + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be a boolean i.e. true or false. Default value is false.\n"); + } } else { - ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_timing_cache_enable', 'trt_force_timing_cache', 'trt_detailed_build_log', 'trt_build_heuristics_enable', 'trt_sparsity_enable', 'trt_builder_optimization_level', 'trt_auxiliary_streams', 'trt_tactic_sources', 'trt_extra_plugin_lib_paths', 'trt_profile_min_shapes', 'trt_profile_max_shapes', 'trt_profile_opt_shapes'] \n"); + ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. 
['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_context_memory_sharing_enable', 'trt_layer_norm_fp32_fallback', 'trt_timing_cache_enable', 'trt_force_timing_cache', 'trt_detailed_build_log', 'trt_build_heuristics_enable', 'trt_sparsity_enable', 'trt_builder_optimization_level', 'trt_auxiliary_streams', 'trt_tactic_sources', 'trt_extra_plugin_lib_paths', 'trt_profile_min_shapes', 'trt_profile_max_shapes', 'trt_profile_opt_shapes', 'trt_cuda_graph_enable'] \n"); } } OrtTensorRTProviderOptionsV2 tensorrt_options; @@ -399,6 +408,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device tensorrt_options.trt_profile_min_shapes = trt_profile_min_shapes.c_str(); tensorrt_options.trt_profile_max_shapes = trt_profile_max_shapes.c_str(); tensorrt_options.trt_profile_opt_shapes = trt_profile_opt_shapes.c_str(); + tensorrt_options.trt_cuda_graph_enable = trt_cuda_graph_enable; session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options); diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index a4d37a1e6cd99..46d68e1680f5a 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -704,7 +704,7 @@ TEST_P(ModelTest, Run) { OrtTensorRTProviderOptionsV2 params{0, 0, nullptr, 1000, 1, 1 << 30, 1, // enable fp16 0, nullptr, 0, 0, 0, 0, 0, nullptr, 0, nullptr, 0, 0, 0, 0, 0, 0, 0, 0, - 3, -1, nullptr, nullptr, nullptr, nullptr, nullptr}; + 3, -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0}; ortso.AppendExecutionProvider_TensorRT_V2(params); } else { diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 81cb844f68f48..e92debf89210f 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -165,7 +165,8 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; params.trt_engine_cache_enable = 1; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); @@ -248,7 +249,8 @@ void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; params.trt_engine_cache_enable = 1; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); @@ -402,7 +404,8 @@ TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); @@ -497,7 +500,8 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { nullptr, nullptr, nullptr, - nullptr}; + nullptr, + 0}; if (cache_type.compare("engine") == 0) { /* Following code block tests the functionality of engine and optimization profile of ORT TRT, including: From 007b64c3be572c09e822ad8b258a803a603a009a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 22 May 2023 17:08:56 +0000 Subject: 
[PATCH 02/15] update --- .../providers/tensorrt/tensorrt_execution_provider.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 27cf7bdce20f9..6a3f96c53f56a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1117,7 +1117,6 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons } Status TensorrtExecutionProvider::OnRunStart() { - std::cout << "OnRunStart() ..." << std::endl; if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; CaptureBegin(); @@ -1126,7 +1125,6 @@ Status TensorrtExecutionProvider::OnRunStart() { } Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { - std::cout << "OnRunEnd() ..." << std::endl; if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } @@ -2241,7 +2239,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); @@ -2895,10 +2891,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector Date: Wed, 24 May 2023 00:19:58 +0000 Subject: [PATCH 03/15] update --- .../tensorrt/tensorrt_execution_provider.cc | 30 +++++---- onnxruntime/core/session/inference_session.cc | 64 +++++++++++-------- 2 files changed, 55 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 6a3f96c53f56a..b8c968f9d8008 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -902,10 +902,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } #if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 - if (cuda_graph_enable_) { - cuda_graph_ = std::make_unique(); - cuda_graph_->SetStream(stream_); - } + cuda_graph_ = std::make_unique(); #endif /* @@ -1117,10 +1114,6 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons } Status TensorrtExecutionProvider::OnRunStart() { - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - CaptureBegin(); - } return Status::OK(); } @@ -2269,6 +2262,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); @@ -2863,6 +2857,17 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsetDeviceMemory((*context_memory).get()); } +#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_->SetStream(stream); + CaptureBegin(); + } +#endif + // Run TRT inference if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); @@ -2891,10 +2896,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector= 10000 + // End CUDA graph capture. 
+ // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture, + // another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc, // which might end up with many cuda graphs are captured by multiple threads if run with multithreading. - // OnRunStart() and ExecuteGraph() are synchronized inside Run(), therefore it's safe to start/end CUDA graph capture in OnRunStart()/compute_func() here. + // It's safe to start/end CUDA graph capture in compute_func() here since the whole fucntion is protected by the lock_guard(). if (cuda_graph_enable_ && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { CaptureEnd(); @@ -2905,6 +2912,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorIsGraphCaptureEnabled()) { - if (cuda_ep->IsGraphCaptureEnabled()) { - if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as the model has control flow nodes which can't be supported by CUDA Graphs."; - - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as the model has control flow nodes which can't be supported by CUDA Graphs.")); - } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all the graph nodes have not been partitioned to the CUDA EP."; - - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all the graph nodes have not been partitioned to the CUDA EP.")); - - } else { - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; - cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep); - } + + if (trt_ep) { + target_ep = trt_ep; + } else if (cuda_ep) { + target_ep = cuda_ep; + } + + if (target_ep && target_ep->IsGraphCaptureEnabled()) { + if (HasControlflowNodes(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. 
+ ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as the model has control flow nodes which can't be supported by CUDA Graphs.")); + } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, target_ep->Type())) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << "as all the graph nodes have not been assigned to " + << target_ep->Type(); + + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + "as all the graph nodes have not been assigned to " + target_ep->Type())); + } else { + LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); } } From 7a38074d3dd82d702801faf1271e0ab524e4acfc Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 May 2023 16:57:21 +0000 Subject: [PATCH 04/15] update --- .../tensorrt/tensorrt_execution_provider.cc | 15 ++-- onnxruntime/core/session/inference_session.cc | 69 +++++++++---------- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index b8c968f9d8008..7ebb255ff9c10 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -902,7 +902,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } #if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 - cuda_graph_ = std::make_unique(); + if (cuda_graph_enable_) { + cuda_graph_ = std::make_unique(); + } #endif /* @@ -1020,7 +1022,7 @@ Status TensorrtExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); // Please note that CUDAGraph::Replay() is not thread safe. // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), - // therefore calling CUDAGraph::Replay() is guaranteed to be thread safe. + // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. return cuda_graph_->Replay(); } @@ -2262,7 +2264,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); @@ -2897,10 +2898,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector= 10000 - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture, - // another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc, - // which might end up with many cuda graphs are captured by multiple threads if run with multithreading. + // End CUDA graph capture. + // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture + // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc, + // which might end up with many cuda graphs are captured by multiple threads if running with multithreading. 
// It's safe to start/end CUDA graph capture in compute_func() here since the whole fucntion is protected by the lock_guard(). if (cuda_graph_enable_ && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cfa7ae370d14b..c5612f626150e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1515,47 +1515,44 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently only CUDA EP and TRT EP are considered. + // Currently CUDA graph is only considered by CUDA EP and TRT EP. // If the CUDA EP is part of the providers list for this session AND // The CUDA EP is configured to do a graph capture AND // All the graph nodes have been assigned to the CUDA EP, // Then the CUDA EP is cached for triggering a ReplayGraph() in Run(). // Same logic is applied to TRT EP. - IExecutionProvider* target_ep = nullptr; - auto* trt_ep = execution_providers_.Get(onnxruntime::kTensorrtExecutionProvider); - auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider); - - if (trt_ep) { - target_ep = trt_ep; - } else if (cuda_ep) { - target_ep = cuda_ep; - } - - if (target_ep && target_ep->IsGraphCaptureEnabled()) { - if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; - - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as the model has control flow nodes which can't be supported by CUDA Graphs.")); - } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, target_ep->Type())) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as all the graph nodes have not been assigned to " - << target_ep->Type(); - - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. 
- ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as all the graph nodes have not been assigned to " + target_ep->Type())); - } else { - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; - cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); + std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + + for (auto& it : cuda_graph_support_ep_list) { + auto* target_ep = execution_providers_.Get(it); + + if (target_ep && target_ep->IsGraphCaptureEnabled()) { + if (HasControlflowNodes(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as the model has control flow nodes which can't be supported by CUDA Graphs.")); + } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, target_ep->Type())) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << "as all the graph nodes have not been assigned to " + << target_ep->Type(); + + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + "as all the graph nodes have not been assigned to " + target_ep->Type())); + } else { + LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); + break; // Make sure only one ep can run CUDA graph. 
+ } } } From f38208787b6018cb8148472f35f22793cf6f5d80 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 May 2023 17:13:30 +0000 Subject: [PATCH 05/15] update --- .../onnxruntime_test_python_cudagraph.py | 82 ++++++++++--------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index 5dd927a566e81..1a203040c6cdf 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -33,46 +33,52 @@ def testOrtValueUpdateInPlace(self): # noqa: N802 ortvalue_gpu.update_inplace(x1) np.testing.assert_allclose(x1, ortvalue_gpu.numpy()) - - def testRunModelWithCudaGraph(self): # noqa: N802 - if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + def testSelectEPtoRunCudaGraph(self): + if "TensorrtExecutionProvider" in onnxrt.get_available_providers(): + providers = [("TensorrtExecutionProvider", {"trt_cuda_graph_enable": True})] + self.RunModelWithCudaGraph(providers) + elif "CUDAExecutionProvider" in onnxrt.get_available_providers(): providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] - INPUT_SIZE = 1280 # noqa: N806 - x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] * INPUT_SIZE, dtype=np.float32) - y = np.array([[0.0], [0.0], [0.0]] * INPUT_SIZE, dtype=np.float32) - x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0) - y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0) - - session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) - io_binding = session.io_binding() - - # Bind the input and output - io_binding.bind_ortvalue_input("X", x_ortvalue) - io_binding.bind_ortvalue_output("Y", y_ortvalue) - - # One regular run for the necessary memory allocation and cuda graph capturing - session.run_with_iobinding(io_binding) - expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32) - np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - - # After capturing, CUDA graph replay happens from this Run onwards - session.run_with_iobinding(io_binding) - np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - - # Update input and then replay CUDA graph - x_ortvalue.update_inplace( - np.array( - [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]] * INPUT_SIZE, - dtype=np.float32, - ) - ) - session.run_with_iobinding(io_binding) - np.testing.assert_allclose( - np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), - y_ortvalue.numpy(), - rtol=1e-05, - atol=1e-05, + self.RunModelWithCudaGraph(providers) + + def RunModelWithCudaGraph(self, providers): # noqa: N802 + INPUT_SIZE = 1280 # noqa: N806 + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] * INPUT_SIZE, dtype=np.float32) + y = np.array([[0.0], [0.0], [0.0]] * INPUT_SIZE, dtype=np.float32) + x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0) + y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0) + + onnxrt.set_default_logger_severity(0) + session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) + io_binding = session.io_binding() + + # Bind the input and output + io_binding.bind_ortvalue_input("X", x_ortvalue) + io_binding.bind_ortvalue_output("Y", y_ortvalue) + + # One regular run for the necessary memory allocation and cuda graph capturing + session.run_with_iobinding(io_binding) + expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, 
dtype=np.float32) + np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) + + # After capturing, CUDA graph replay happens from this Run onwards + session.run_with_iobinding(io_binding) + np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) + + # Update input and then replay CUDA graph + x_ortvalue.update_inplace( + np.array( + [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]] * INPUT_SIZE, + dtype=np.float32, ) + ) + session.run_with_iobinding(io_binding) + np.testing.assert_allclose( + np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), + y_ortvalue.numpy(), + rtol=1e-05, + atol=1e-05, + ) if __name__ == "__main__": From 2e9c64677db16a410fd5b800c36adf1633024994 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 May 2023 22:17:21 +0000 Subject: [PATCH 06/15] fix bug --- cmake/onnxruntime_providers.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 0d1f5c4cc1145..67d8dbc00b005 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -729,6 +729,8 @@ if (onnxruntime_USE_TENSORRT) "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.h" "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_stream_handle.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_graph.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs}) From 4d022bfe755ccc65a3ddfced83414160a3cf205b Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 May 2023 23:37:20 +0000 Subject: [PATCH 07/15] add inference test for TRT EP --- onnxruntime/test/shared_lib/test_inference.cc | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index f53a93225b524..d30579a02907c 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1740,10 +1740,25 @@ TEST(CApiTest, io_binding_cuda) { } #endif -#if defined(USE_CUDA) +#if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph) { const auto& api = Ort::GetApi(); + Ort::SessionOptions session_options; +#if defined(USE_TENSORRT) + // Enable cuda graph in TRT provider option. + OrtTensorRTProviderOptionsV2* trt_options; + ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); + std::unique_ptr + rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + std::vector keys{"trt_cuda_graph_enable"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options), + rel_trt_options.get()) == nullptr); +#else // Enable cuda graph in cuda provider option. 
OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -1753,10 +1768,10 @@ TEST(CApiTest, cuda_graph) { std::vector values{"1"}; ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); - Ort::SessionOptions session_options; ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( static_cast(session_options), rel_cuda_options.get()) == nullptr); +#endif // Create IoBinding for inputs and outputs. Ort::Session session(*ort_env, MODEL_URI, session_options); From 772c8494ee56cad6a303938a2bc9c3bbd324e2bf Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 May 2023 17:25:56 +0000 Subject: [PATCH 08/15] fix format --- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 2 +- .../core/providers/tensorrt/tensorrt_execution_provider.h | 6 +++--- onnxruntime/core/session/inference_session.cc | 5 +++-- .../test/python/onnxruntime_test_python_cudagraph.py | 1 + 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7ebb255ff9c10..3078d4a749177 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1004,7 +1004,7 @@ bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::CaptureBegin() { +void TensorrtExecutionProvider::CaptureBegin() { cuda_graph_->Reset(); cuda_graph_->CaptureBegin(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 7d2c8780267b4..2d029e554ad50 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -220,13 +220,13 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool timing_cache_enable_ = false; bool force_timing_cache_match_ = false; bool detailed_build_log_ = false; - bool cuda_graph_enable_ = false; + bool cuda_graph_enable_ = false; #if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 - std::unique_ptr cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) + std::unique_ptr cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
#endif std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ea8883d17a945..0349d220f1811 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1553,11 +1553,12 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This session cannot use the CUDA Graph feature as requested by the user " - "as all the graph nodes have not been assigned to " + target_ep->Type())); + "as all the graph nodes have not been assigned to " + + target_ep->Type())); } else { LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); - break; // Make sure only one ep can run CUDA graph. + break; // Make sure only one ep can run CUDA graph. } } } diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index 1a203040c6cdf..d1d8105758ad0 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -33,6 +33,7 @@ def testOrtValueUpdateInPlace(self): # noqa: N802 ortvalue_gpu.update_inplace(x1) np.testing.assert_allclose(x1, ortvalue_gpu.numpy()) + def testSelectEPtoRunCudaGraph(self): if "TensorrtExecutionProvider" in onnxrt.get_available_providers(): providers = [("TensorrtExecutionProvider", {"trt_cuda_graph_enable": True})] From 4122d69956703c1a8052884a20f0fb8e62af69ec Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 May 2023 17:59:31 +0000 Subject: [PATCH 09/15] fix format --- .../test/python/onnxruntime_test_python_cudagraph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index d1d8105758ad0..6d978eebeb9a8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -18,7 +18,7 @@ class TestInferenceSessionWithCudaGraph(unittest.TestCase): - def testOrtValueUpdateInPlace(self): # noqa: N802 + def test_ort_value_update_in_place(self): x0 = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) ortvalue_cpu = onnxrt.OrtValue.ortvalue_from_numpy(x0) np.testing.assert_allclose(x0, ortvalue_cpu.numpy()) @@ -34,7 +34,7 @@ def testOrtValueUpdateInPlace(self): # noqa: N802 ortvalue_gpu.update_inplace(x1) np.testing.assert_allclose(x1, ortvalue_gpu.numpy()) - def testSelectEPtoRunCudaGraph(self): + def test_select_ep_to_run_cuda_graph(self): if "TensorrtExecutionProvider" in onnxrt.get_available_providers(): providers = [("TensorrtExecutionProvider", {"trt_cuda_graph_enable": True})] self.RunModelWithCudaGraph(providers) @@ -42,7 +42,7 @@ def testSelectEPtoRunCudaGraph(self): providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] self.RunModelWithCudaGraph(providers) - def RunModelWithCudaGraph(self, providers): # noqa: N802 + def run_model_with_cuda_graph(self, providers): INPUT_SIZE = 1280 # noqa: N806 x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] * INPUT_SIZE, dtype=np.float32) y = np.array([[0.0], [0.0], [0.0]] * INPUT_SIZE, dtype=np.float32) From 1bc69e1a702045a7cc9fe2adc8e383ee4cc84101 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: 
Thu, 25 May 2023 21:05:38 +0000 Subject: [PATCH 10/15] fix bug --- onnxruntime/test/python/onnxruntime_test_python_cudagraph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index 6d978eebeb9a8..cd5fd0cb61983 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -37,10 +37,10 @@ def test_ort_value_update_in_place(self): def test_select_ep_to_run_cuda_graph(self): if "TensorrtExecutionProvider" in onnxrt.get_available_providers(): providers = [("TensorrtExecutionProvider", {"trt_cuda_graph_enable": True})] - self.RunModelWithCudaGraph(providers) + self.run_model_with_cuda_graph(providers) elif "CUDAExecutionProvider" in onnxrt.get_available_providers(): providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] - self.RunModelWithCudaGraph(providers) + self.run_model_with_cuda_graph(providers) def run_model_with_cuda_graph(self, providers): INPUT_SIZE = 1280 # noqa: N806 From 8f6390ae0c38fcb625f40d59425765209c96da98 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 19 Jun 2023 19:18:21 +0000 Subject: [PATCH 11/15] fix typo --- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 2 +- onnxruntime/core/session/inference_session.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 3078d4a749177..de38cd4289533 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2902,7 +2902,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector Date: Mon, 19 Jun 2023 19:19:23 +0000 Subject: [PATCH 12/15] fix format --- onnxruntime/core/session/inference_session.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 363a71534a205..8519c049dad57 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -176,7 +176,7 @@ std::pair AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { return std::make_pair(true, static_cast(shape_nodes.size())); } - + bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); @@ -1612,8 +1612,8 @@ common::Status InferenceSession::Initialize() { ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This session cannot use the CUDA Graph feature as requested by the user " "as the model has control flow nodes which can't be supported by CUDA Graphs.")); - } - + } + if (strcmp(target_ep->Type(), onnxruntime::kCudaExecutionProvider) == 0) { auto res = AreAllComputeNodesAssignedToCudaEp(graph); @@ -1628,7 +1628,7 @@ common::Status InferenceSession::Initialize() { "This session cannot use the CUDA Graph feature as requested by the user " " as all compute graph nodes have not been partitioned to the CUDA EP.")); } - + if (res.second > 0) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " << "Use the CUDA Graph feature with caution. 
" @@ -1653,7 +1653,7 @@ common::Status InferenceSession::Initialize() { target_ep->Type())); } } - + LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); break; // Make sure only one ep can run CUDA graph. From 23734d5c4d6aa4d95a6679409cd6b7f05e2c7f9c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 19 Jun 2023 20:22:25 +0000 Subject: [PATCH 13/15] remove cuda version macro --- .../providers/tensorrt/tensorrt_execution_provider.cc | 8 -------- .../core/providers/tensorrt/tensorrt_execution_provider.h | 6 ------ 2 files changed, 14 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index de38cd4289533..1ef26e901cd73 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -901,11 +901,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } -#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 if (cuda_graph_enable_) { cuda_graph_ = std::make_unique(); } -#endif /* * Parse explicit min/max/opt profile shapes from provider options. @@ -995,7 +993,6 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); } -#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } @@ -1032,7 +1029,6 @@ void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { // therefore following increment is guaranteed to be thread safe. ++regular_run_count_before_graph_capture_; } -#endif AllocatorPtr TensorrtExecutionProvider::GetAllocator(OrtMemType mem_type) const { if (mem_type == OrtMemTypeDefault) { @@ -2858,7 +2854,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsetDeviceMemory((*context_memory).get()); } -#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. @@ -2867,7 +2862,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorSetStream(stream); CaptureBegin(); } -#endif // Run TRT inference if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { @@ -2897,7 +2891,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector= 10000 // End CUDA graph capture. 
// Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc, @@ -2913,7 +2906,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector= 10000 bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured() const override; Status ReplayGraph() override; -#endif private: TensorrtExecutionProviderInfo info_; @@ -222,12 +220,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; -#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 std::unique_ptr cuda_graph_; // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. -#endif std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; std::unordered_map> parsers_; @@ -279,11 +275,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { /**Check whether all the nodes of subgraph are supported*/ bool IsSubGraphFullySupported(SubGraphCollection_t supported_nodes_vector, const int number_of_ort_nodes) const; -#if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 bool IsGraphCaptureAllowed() const; void CaptureBegin(); void CaptureEnd(); void IncrementRegularRunCountBeforeGraphCapture(); -#endif }; } // namespace onnxruntime From 2f0d48b859cdb5c335715276ca203564acee63f4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 19 Jun 2023 20:22:35 +0000 Subject: [PATCH 14/15] fix bug --- onnxruntime/core/session/inference_session.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 8519c049dad57..27a1c3a833d01 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1614,7 +1614,7 @@ common::Status InferenceSession::Initialize() { "as the model has control flow nodes which can't be supported by CUDA Graphs.")); } - if (strcmp(target_ep->Type(), onnxruntime::kCudaExecutionProvider) == 0) { + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { auto res = AreAllComputeNodesAssignedToCudaEp(graph); if (!res.first) { From 3ae889ecab3e46647b128918c59ba145e869935b Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 19 Jun 2023 22:15:34 +0000 Subject: [PATCH 15/15] modify comments per reviewer --- onnxruntime/core/session/inference_session.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 27a1c3a833d01..a777ac9e78219 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1591,11 +1591,18 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); // Currently CUDA graph is only considered by CUDA EP and TRT EP. 
+ // + // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND // The CUDA EP is configured to do a graph capture AND - // All the compute graph nodes have been assigned to the CUDA EP, + // All the "compute" graph nodes have been assigned to the CUDA EP, // Then the CUDA EP is cached for triggering a ReplayGraph() in Run(). - // Similar logic is applied to TRT EP. + // + // Check for TRT EP: + // If the TRT EP is part of the providers list for this session AND + // The TRT EP is configured to do a graph capture AND + // All the graph nodes have been assigned to the TRT EP, + // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; for (auto& it : cuda_graph_support_ep_list) { @@ -1639,6 +1646,7 @@ common::Status InferenceSession::Initialize() { << "it is safe to use the CUDA Graph feature."; } } else { + // Following code path is for TRT EP currently. if (!AreAllNodesInMainGraphAssignedToOneEp(graph, target_ep->Type())) { LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " << "as all the graph nodes have not been assigned to "