From c7c1e826d9f5d7fea7d634f9e7d191f318351ce3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 14 Oct 2024 17:10:11 -0700 Subject: [PATCH] support session options "optimizedModelFilePath" and "extra" --- js/node/src/session_options_helper.cc | 38 +++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 9e4f6c23d5037..d560011391a7e 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -122,6 +122,22 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess } } +void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) { + for (const auto& kvp : obj) { + auto key = kvp.first.As().Utf8Value(); + Napi::Value value = kvp.second; + if (value.IsObject()) { + IterateExtraOptions(prefix + key + ".", value.As(), sessionOptions); + } else { + ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString(), obj.Env(), + "Invalid argument: sessionOptions.extra value must be a string in Node.js binding."); + std::string entry = prefix + key; + auto val = value.As().Utf8Value(); + sessionOptions.AddConfigEntry(entry.c_str(), val.c_str()); + } + } +} + void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) { // Execution provider if (options.Has("executionProviders")) { @@ -189,6 +205,28 @@ void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessio } } + // optimizedModelFilePath + if (options.Has("optimizedModelFilePath")) { + auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath"); + ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(), + "Invalid argument: sessionOptions.optimizedModelFilePath must be a string."); +#ifdef _WIN32 + auto str = optimizedModelFilePathValue.As().Utf16Value(); + std::basic_string optimizedModelFilePath = std::wstring{str.begin(), str.end()}; +#else + std::basic_string optimizedModelFilePath = optimizedModelFilePathValue.As().Utf8Value(); +#endif + sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str()); + } + + // extra + if (options.Has("extra")) { + auto extraValue = options.Get("extra"); + ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(), + "Invalid argument: sessionOptions.extra must be an object."); + IterateExtraOptions("", extraValue.As(), sessionOptions); + } + // execution mode if (options.Has("executionMode")) { auto executionModeValue = options.Get("executionMode");