Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QNN EP option context_node_name_prefix to set EPContext node name prefix #21236

Merged
merged 8 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,17 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
}

// To work around the QNN context PD memory limit, the user may need to split the model into pieces and
// generate a QNN context model for each piece separately.
// In that case, the EPContext nodes generated in the separate graphs could end up with the same node name.
// The user can set a distinct context_node_name_prefix for each piece to avoid such name collisions.
static const std::string QNN_CONTEXT_NODE_NAME_PREFIX = "context_node_name_prefix";
auto context_node_name_prefix_pos = provider_options_map.find(QNN_CONTEXT_NODE_NAME_PREFIX);
if (context_node_name_prefix_pos != provider_options_map.end()) {
context_node_name_prefix_ = context_node_name_prefix_pos->second;
LOGS_DEFAULT(VERBOSE) << "User specified QNN context node name prefix: " << context_node_name_prefix_;
}

qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
std::move(backend_path),
profiling_level_etw,
Expand Down Expand Up @@ -612,7 +623,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
const auto gen_metadef_name = [&]() {
uint64_t model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(QNN, "_", model_hash, "_", metadef_id);
return MakeString(QNN, context_node_name_prefix_, "_", model_hash, "_", metadef_id);
};

// For model with EPContext, make sure each partition only has one single EPContext node
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class QNNExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>> qnn_models_;
bool context_cache_enabled_ = false;
std::string context_cache_path_cfg_ = "";
std::string context_node_name_prefix_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
bool qnn_context_embed_mode_ = true;
int32_t vtcm_size_in_mb_ = 0;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/perftest/command_args_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
"\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
"\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
"\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
"\t [QNN only] [context_node_name_prefix]: QNN EPContext node name prefix to make sure the generated node name unique for split models. e.g 'phi3_part1'.\n"

Check warning on line 100 in onnxruntime/test/perftest/command_args_parser.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/perftest/command_args_parser.cc:100: Lines should be <= 120 characters long [whitespace/line_length] [2]
HectorSVC marked this conversation as resolved.
Show resolved Hide resolved
"\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
"\n"
"\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
Expand Down
39 changes: 39 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,45 @@
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
}

// Verifies that the QNN EP provider option "context_node_name_prefix" is applied:
// every EPContext node in the generated context cache model must carry the
// configured prefix in its node name.
TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) {
  ProviderOptions provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif
  std::string node_name_prefix = "node_name_prefix_test";
  provider_options["context_node_name_prefix"] = node_name_prefix;

  auto& logging_manager = DefaultLoggingManager();
  logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);

  const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx";
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
  so.AppendExecutionProvider("QNN", provider_options);

  Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);

  // ASSERT (not EXPECT): the checks below cannot proceed without the generated file.
  ASSERT_TRUE(std::filesystem::exists(context_binary_file.c_str()));

  std::shared_ptr<Model> model;
  ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr,
                               DefaultLoggingManager().DefaultLogger()));

  // Every EPContext node name must contain the configured prefix. Count the
  // EPContext nodes so the test fails (instead of passing vacuously) if the
  // generated model contains none at all.
  int ep_context_node_count = 0;
  for (auto& node : model->MainGraph().Nodes()) {
    if (node.OpType() == "EPContext") {
      ++ep_context_node_count;
      EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos);
    }
  }
  EXPECT_GT(ep_context_node_count, 0);

  // clean up
  ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
}

// Run QDQ model on HTP 3 times
// 1st run will generate the Qnn context cache onnx file
// 2nd run directly loads and run from Qnn context cache model
Expand Down
Loading