Skip to content

Commit

Permalink
change shared_ptr to unique_ptr to make the ownership clear
Browse files Browse the repository at this point in the history
  • Loading branch information
HectorSVC committed Sep 24, 2024
1 parent 7811839 commit bc73d0f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 21 deletions.
19 changes: 7 additions & 12 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ static bool EpSharedContextsHasAllGraphs(const onnxruntime::GraphViewer& graph_v

if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) {
const std::string& graph_name = node.Name();
auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name);
if (nullptr == shared_qnn_model) {
bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name);
if (!has_shared_qnn_model) {
LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts.";
return false;
}
Expand All @@ -566,8 +566,8 @@ static bool EpSharedContextsHasAllGraphs(const std::vector<IExecutionProvider::F
std::string cache_source = node_helper.Get(qnn::SOURCE, "");

const std::string& graph_name = ep_context_node->Name();
auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name);
if (nullptr == shared_qnn_model) {
bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name);
if (!has_shared_qnn_model) {
LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts.";
return false;
}
Expand Down Expand Up @@ -776,10 +776,6 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& nod
NodeComputeInfo compute_info;
compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) {
LOGS(logger, VERBOSE) << "compute_info.create_state_func context->node_name: " << context->node_name;
if (use_shared_model_) {
*state = qnn_models_shared_[context->node_name].get();
return 0;
}
*state = qnn_models_[context->node_name].get();
return 0;
};
Expand Down Expand Up @@ -895,8 +891,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not found from shared EP contexts.");
ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger));
ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput(logger));
qnn_models_shared_.emplace(graph_meta_id, qnn_model_shared);
use_shared_model_ = true;
qnn_models_.emplace(graph_meta_id, std::move(qnn_model_shared));
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
}
return Status::OK();
Expand Down Expand Up @@ -940,12 +935,12 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
}

if (share_ep_contexts_ && qnn_models.size() > 0) {
std::vector<std::shared_ptr<qnn::QnnModel>> shared_qnn_models;
std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models;
for (auto& [key, value] : qnn_models) {
shared_qnn_models.push_back(std::move(qnn_models[key]));
}
std::string duplicate_graph_names;
bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(shared_qnn_models,
bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(std::move(shared_qnn_models),
duplicate_graph_names);
ORT_RETURN_IF(has_duplicate_graph, "Duplicate graph names detect across sessions: " + duplicate_graph_names);
}
Expand Down
24 changes: 15 additions & 9 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,34 @@ class SharedContext {
return !shared_qnn_models_.empty();
}

std::shared_ptr<qnn::QnnModel> GetSharedQnnModel(const std::string& model_name) {
bool HasQnnModel(const std::string& model_name) {
auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(),
[&model_name](const std::unique_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });

Check warning on line 40 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:40: Lines should be <= 120 characters long [whitespace/line_length] [2]
return it != shared_qnn_models_.end();
}

std::unique_ptr<qnn::QnnModel> GetSharedQnnModel(const std::string& model_name) {
const std::lock_guard<OrtMutex> lock(mtx_);
auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(),
[&model_name](const std::shared_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });
[&model_name](const std::unique_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });

Check warning on line 47 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:47: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (it == shared_qnn_models_.end()) {
return nullptr;
}
return *it;
auto qnn_model = std::move(*it);
shared_qnn_models_.erase(it);
return qnn_model;
}

bool SetSharedQnnModel(std::vector<std::shared_ptr<qnn::QnnModel>>& shared_qnn_models,
bool SetSharedQnnModel(std::vector<std::unique_ptr<qnn::QnnModel>>&& shared_qnn_models,
std::string& duplicate_graph_names) {
const std::lock_guard<OrtMutex> lock(mtx_);
bool graph_exist = false;
for (auto& shared_qnn_model : shared_qnn_models) {
auto& model_name = shared_qnn_model->Name();
auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(),
[&model_name](const std::shared_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });
[&model_name](const std::unique_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });

Check warning on line 63 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:63: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (it == shared_qnn_models_.end()) {
shared_qnn_models_.push_back(shared_qnn_model);
shared_qnn_models_.push_back(std::move(shared_qnn_model));

Check warning on line 65 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:65: Add #include <utility> for move [build/include_what_you_use] [4]
} else {
duplicate_graph_names.append(model_name + " ");
graph_exist = true;
Expand All @@ -70,7 +78,7 @@ class SharedContext {
SharedContext(const SharedContext&) = delete;
SharedContext& operator=(const SharedContext&) = delete;

std::vector<std::shared_ptr<qnn::QnnModel>> shared_qnn_models_;
std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models_;
// Producer sessions can be in parallel
// Consumer sessions have to be after producer sessions initialized
OrtMutex mtx_;
Expand Down Expand Up @@ -128,8 +136,6 @@ class QNNExecutionProvider : public IExecutionProvider {
qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
std::unique_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models_;
std::unordered_map<std::string, std::shared_ptr<qnn::QnnModel>> qnn_models_shared_;
bool use_shared_model_ = false;
bool context_cache_enabled_ = false;
std::string context_cache_path_cfg_ = "";
std::string context_node_name_prefix_ = "";
Expand Down

0 comments on commit bc73d0f

Please sign in to comment.