Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change shared_ptr to unique_ptr to make the ownership clear #22209

Merged
merged 1 commit on Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ static bool EpSharedContextsHasAllGraphs(const onnxruntime::GraphViewer& graph_v

if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) {
const std::string& graph_name = node.Name();
auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name);
if (nullptr == shared_qnn_model) {
bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name);
if (!has_shared_qnn_model) {
LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts.";
return false;
}
Expand All @@ -566,8 +566,8 @@ static bool EpSharedContextsHasAllGraphs(const std::vector<IExecutionProvider::F
std::string cache_source = node_helper.Get(qnn::SOURCE, "");

const std::string& graph_name = ep_context_node->Name();
auto shared_qnn_model = SharedContext::GetInstance().GetSharedQnnModel(graph_name);
if (nullptr == shared_qnn_model) {
bool has_shared_qnn_model = SharedContext::GetInstance().HasQnnModel(graph_name);
if (!has_shared_qnn_model) {
LOGS(logger, VERBOSE) << "Graph: " << graph_name << " from EpContext node not found from shared EP contexts.";
return false;
}
Expand Down Expand Up @@ -776,10 +776,6 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& nod
NodeComputeInfo compute_info;
compute_info.create_state_func = [&](ComputeContext* context, FunctionState* state) {
LOGS(logger, VERBOSE) << "compute_info.create_state_func context->node_name: " << context->node_name;
if (use_shared_model_) {
*state = qnn_models_shared_[context->node_name].get();
return 0;
}
*state = qnn_models_[context->node_name].get();
return 0;
};
Expand Down Expand Up @@ -895,8 +891,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not found from shared EP contexts.");
ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger));
ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput(logger));
qnn_models_shared_.emplace(graph_meta_id, qnn_model_shared);
use_shared_model_ = true;
qnn_models_.emplace(graph_meta_id, std::move(qnn_model_shared));
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
}
return Status::OK();
Expand Down Expand Up @@ -940,12 +935,12 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
}

if (share_ep_contexts_ && qnn_models.size() > 0) {
std::vector<std::shared_ptr<qnn::QnnModel>> shared_qnn_models;
std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models;
for (auto& [key, value] : qnn_models) {
shared_qnn_models.push_back(std::move(qnn_models[key]));
}
std::string duplicate_graph_names;
bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(shared_qnn_models,
bool has_duplicate_graph = SharedContext::GetInstance().SetSharedQnnModel(std::move(shared_qnn_models),
duplicate_graph_names);
ORT_RETURN_IF(has_duplicate_graph, "Duplicate graph names detect across sessions: " + duplicate_graph_names);
}
Expand Down
24 changes: 15 additions & 9 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,34 @@
return !shared_qnn_models_.empty();
}

std::shared_ptr<qnn::QnnModel> GetSharedQnnModel(const std::string& model_name) {
bool HasQnnModel(const std::string& model_name) {
auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(),
[&model_name](const std::unique_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });

Check warning on line 40 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:40: Lines should be <= 120 characters long [whitespace/line_length] [2]
return it != shared_qnn_models_.end();
}

std::unique_ptr<qnn::QnnModel> GetSharedQnnModel(const std::string& model_name) {
const std::lock_guard<OrtMutex> lock(mtx_);
auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(),
[&model_name](const std::shared_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });
[&model_name](const std::unique_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });

Check warning on line 47 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:47: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (it == shared_qnn_models_.end()) {
return nullptr;
}
return *it;
auto qnn_model = std::move(*it);
shared_qnn_models_.erase(it);
return qnn_model;
}

bool SetSharedQnnModel(std::vector<std::shared_ptr<qnn::QnnModel>>& shared_qnn_models,
bool SetSharedQnnModel(std::vector<std::unique_ptr<qnn::QnnModel>>&& shared_qnn_models,
std::string& duplicate_graph_names) {
const std::lock_guard<OrtMutex> lock(mtx_);
bool graph_exist = false;
for (auto& shared_qnn_model : shared_qnn_models) {
auto& model_name = shared_qnn_model->Name();
auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(),
[&model_name](const std::shared_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });
[&model_name](const std::unique_ptr<qnn::QnnModel>& qnn_model) { return qnn_model->Name() == model_name; });

Check warning on line 63 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:63: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (it == shared_qnn_models_.end()) {
shared_qnn_models_.push_back(shared_qnn_model);
shared_qnn_models_.push_back(std::move(shared_qnn_model));

Check warning on line 65 in onnxruntime/core/providers/qnn/qnn_execution_provider.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/qnn_execution_provider.h:65: Add #include <utility> for move [build/include_what_you_use] [4]
} else {
duplicate_graph_names.append(model_name + " ");
graph_exist = true;
Expand All @@ -70,7 +78,7 @@
SharedContext(const SharedContext&) = delete;
SharedContext& operator=(const SharedContext&) = delete;

std::vector<std::shared_ptr<qnn::QnnModel>> shared_qnn_models_;
std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models_;
// Producer sessions can be in parallel
// Consumer sessions have to be after producer sessions initialized
OrtMutex mtx_;
Expand Down Expand Up @@ -128,8 +136,6 @@
qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault;
std::unique_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models_;
std::unordered_map<std::string, std::shared_ptr<qnn::QnnModel>> qnn_models_shared_;
bool use_shared_model_ = false;
bool context_cache_enabled_ = false;
std::string context_cache_path_cfg_ = "";
std::string context_node_name_prefix_ = "";
Expand Down
Loading