diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 698ceaea7c3b7..24132b98e3757 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -546,8 +546,8 @@ static bool EpSharedContextsHasAllGraphs(const onnxruntime::GraphViewer& graph_v if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) { const std::string& graph_name = node.Name(); - auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name); - if (nullptr == shared_qnn_model) { + bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name); + if (!has_shared_qnn_model) { LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts."; return false; } @@ -566,8 +566,8 @@ static bool EpSharedContextsHasAllGraphs(const std::vectorName(); - auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name); - if (nullptr == shared_qnn_model) { + bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name); + if (!has_shared_qnn_model) { LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts."; return false; } @@ -776,10 +776,6 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod NodeComputeInfo compute_info; compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) { LOGS(logger, VERBOSE) << "compute_info.create_state_func context->node_name: " << context->node_name; - if (use_shared_model_) { - *state = qnn_models_shared_[context->node_name].get(); - return 0; - } *state = qnn_models_[context->node_name].get(); return 0; }; @@ -895,8 +891,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not found from shared EP contexts."); ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput(logger)); - qnn_models_shared_.emplace(graph_meta_id, qnn_model_shared); - use_shared_model_ = true; + qnn_models_.emplace(graph_meta_id, std::move(qnn_model_shared)); ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); } return Status::OK(); @@ -940,12 +935,12 @@ Status QNNExecutionProvider::Compile(const std::vector& fused } if (share_ep_contexts_ && qnn_models.size() > 0) { - std::vector> shared_qnn_models; + std::vector> shared_qnn_models; for (auto& [key, value] : qnn_models) { shared_qnn_models.push_back(std::move(qnn_models[key])); } std::string duplicate_graph_names; - bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(shared_qnn_models, + bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(std::move(shared_qnn_models), duplicate_graph_names); ORT_RETURN_IF(has_duplicate_graph, "Duplicate graph names detect across sessions: " + duplicate_graph_names); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index ac9098f907975..e0eaf31c94a36 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -35,26 +35,34 @@ class SharedContext { return !shared_qnn_models_.empty(); } - std::shared_ptr GetSharedQnnModel(const std::string& model_name) { + bool HasQnnModel(const std::string& model_name) { + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + return it != shared_qnn_models_.end(); + } + + std::unique_ptr GetSharedQnnModel(const std::string& model_name) { const std::lock_guard lock(mtx_); auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::shared_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); if (it == shared_qnn_models_.end()) { return nullptr; } - return *it; + auto qnn_model = std::move(*it); + shared_qnn_models_.erase(it); + return qnn_model; } - bool SetSharedQnnModel(std::vector>& shared_qnn_models, + bool SetSharedQnnModel(std::vector>&& shared_qnn_models, std::string& duplicate_graph_names) { const std::lock_guard lock(mtx_); bool graph_exist = false; for (auto& shared_qnn_model : shared_qnn_models) { auto& model_name = shared_qnn_model->Name(); auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::shared_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); if (it == shared_qnn_models_.end()) { - shared_qnn_models_.push_back(shared_qnn_model); + shared_qnn_models_.push_back(std::move(shared_qnn_model)); } else { duplicate_graph_names.append(model_name + " "); graph_exist = true; @@ -70,7 +78,7 @@ class SharedContext { SharedContext(const SharedContext&) = delete; SharedContext& operator=(const SharedContext&) = delete; - std::vector> shared_qnn_models_; + std::vector> shared_qnn_models_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized OrtMutex mtx_; @@ -128,8 +136,6 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; std::unique_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; - std::unordered_map> qnn_models_shared_; - bool use_shared_model_ = false; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; std::string context_node_name_prefix_ = "";