From e034c53ab35595054d5814f3ae0b6ed5c1564795 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Mon, 15 Jan 2024 14:59:07 +0800
Subject: [PATCH 1/9] [Draft][WebNN EP] Test WebNN async API with Asyncify

---
 cmake/onnxruntime_webassembly.cmake                        | 2 +-
 onnxruntime/core/providers/webnn/builders/model_builder.cc | 7 ++++++-
 .../core/providers/webnn/webnn_execution_provider.cc       | 6 +++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 546d50c1ca2d3..729556756cb71 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -268,7 +268,7 @@ else()
   endif()
 
   if (onnxruntime_USE_WEBNN)
-    set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT")
+    set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT -sASYNCIFY=1 -sASYNCIFY_STACK_SIZE=65536")
     if (onnxruntime_DISABLE_RTTI)
       set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0")
     endif()
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index cf8a0e23db43b..c52a9d26d32f5 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -386,7 +386,12 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   for (auto& name : output_names_) {
     named_operands.set(name, wnn_operands_.at(name));
   }
-  emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("buildSync", named_operands);
+
+  emscripten::val console = emscripten::val::global("console");
+  console.call<void>("log", emscripten::val("start webnn async build()..."));
+  emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("build", named_operands).await();
+  console.call<void>("log", wnn_builder_);
+  console.call<void>("log", emscripten::val("Done webnn async build()..."));
   if (!wnn_graph.as<bool>()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
   }
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index 2922cf9540a8e..be5bd57178c32 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -42,7 +42,11 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
   if (webnn_power_flags.compare("default") != 0) {
     context_options.set("powerPreference", emscripten::val(webnn_power_flags));
   }
-  wnn_context_ = ml.call<emscripten::val>("createContextSync", context_options);
+  emscripten::val console = emscripten::val::global("console");
+  console.call<void>("log", emscripten::val("start webnn async createContext()..."));
+  wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
+  console.call<void>("log", wnn_context_);
+  console.call<void>("log", emscripten::val("Done webnn async createContext()..."));
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }

From a5dca52130a8b98925c644b6df0bd8b4fb71c914 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Wed, 17 Jan 2024 20:58:56 +0800
Subject: [PATCH 2/9] Try to register async _OrtCreateSession in jsep

---
 js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +-
 js/web/lib/wasm/wasm-core-impl.ts     | 2 +-
 js/web/script/build.ts                | 2 +-
 onnxruntime/wasm/js_internal_api.js   | 4 ++++
 4 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 9d4d5875310b7..6d21d4fb5b8a7 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {
 
   _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
 
-  _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
+  _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
   _OrtReleaseSession(sessionHandle: number): void;
   _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
   _OrtGetInputName(sessionHandle: number, index: number): number;
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 5821fac3c468f..8243ebbbe1e41 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -228,7 +228,7 @@ export const createSession = async(
     await Promise.all(loadingPromises);
   }
 
-  sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
+  sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
   if (sessionHandle === 0) {
     checkLastError('Can\'t create a session.');
   }
diff --git a/js/web/script/build.ts b/js/web/script/build.ts
index ea0c122cb51de..eac7a7efd405e 100644
--- a/js/web/script/build.ts
+++ b/js/web/script/build.ts
@@ -397,7 +397,7 @@ async function main() {
   // ort.webgpu[.min].js
   await addAllWebBuildTasks({
     outputBundleName: 'ort.webgpu',
-    define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
+    define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'false'},
   });
   // ort.wasm[.min].js
   await addAllWebBuildTasks({
diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js
index 25ece9c700d5d..dc7fca902e01b 100644
--- a/onnxruntime/wasm/js_internal_api.js
+++ b/onnxruntime/wasm/js_internal_api.js
@@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
   };
 
   // replace the original functions with asyncified versions
+  Module['_OrtCreateSession'] = runAsync(jsepWrapAsync(
+      Module['_OrtCreateSession'],
+      () => Module['_OrtCreateSession'],
+      v => Module['_OrtCreateSession'] = v));
   Module['_OrtRun'] = runAsync(jsepWrapAsync(
       Module['_OrtRun'],
       () => Module['_OrtRun'],

From 649e1f538157f6a8105cd8d41e90facbd2c2f419 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Thu, 18 Jan 2024 09:32:57 +0800
Subject: [PATCH 3/9] Addressed comments and added debug log

---
 js/web/script/build.ts              | 2 +-
 onnxruntime/wasm/js_internal_api.js | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/js/web/script/build.ts b/js/web/script/build.ts
index eac7a7efd405e..374d6a1c22194 100644
--- a/js/web/script/build.ts
+++ b/js/web/script/build.ts
@@ -397,7 +397,7 @@ async function main() {
   // ort.webgpu[.min].js
   await addAllWebBuildTasks({
     outputBundleName: 'ort.webgpu',
-    define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'false'},
+    define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
   });
   // ort.wasm[.min].js
   await addAllWebBuildTasks({
diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js
index dc7fca902e01b..e15ab0872f2eb 100644
--- a/onnxruntime/wasm/js_internal_api.js
+++ b/onnxruntime/wasm/js_internal_api.js
@@ -93,6 +93,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
   //
   const jsepWrapAsync = (func, getFunc, setFunc) => {
     return (...args) => {
+      console.log('log from jsepWrapAsync start');
       // cache the async data before calling the function.
       const previousAsync = Asyncify.currData;
 
@@ -115,6 +116,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
         // returns the promise
         return Asyncify.whenDone();
       }
+      console.log('log from jsepWrapAsync end');
       // the function is synchronous. returns the result.
       return ret;
     };
@@ -160,10 +162,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
   };
 
   // replace the original functions with asyncified versions
-  Module['_OrtCreateSession'] = runAsync(jsepWrapAsync(
+  Module['_OrtCreateSession'] = jsepWrapAsync(
       Module['_OrtCreateSession'],
       () => Module['_OrtCreateSession'],
-      v => Module['_OrtCreateSession'] = v));
+      v => Module['_OrtCreateSession'] = v);
   Module['_OrtRun'] = runAsync(jsepWrapAsync(
       Module['_OrtRun'],
       () => Module['_OrtRun'],

From 7f96b039ae4c9739f5e183d9def1886cf42544a8 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Thu, 18 Jan 2024 09:56:47 +0800
Subject: [PATCH 4/9] Let webnn be used in initJsep

---
 js/web/lib/wasm/wasm-core-impl.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 8243ebbbe1e41..01294595ce878 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
  * @param epName
  */
 export const initEp = async(env: Env, epName: string): Promise<void> => {
-  if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
+  if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
     // perform WebGPU availability check
     if (typeof navigator === 'undefined' || !navigator.gpu) {
       throw new Error('WebGPU is not supported in current environment');

From d8a5c92fc1b4bfeeb8f0f37678c97bcff5232d07 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Fri, 19 Jan 2024 08:40:59 +0800
Subject: [PATCH 5/9] Enable webnn async compute, remove debug log, remove
 unnecessary flags

---
 cmake/onnxruntime_webassembly.cmake            |  2 +-
 js/web/lib/build-def.d.ts                      |  4 ---
 js/web/lib/index.ts                            |  4 +--
 js/web/script/build.ts                         |  5 ---
 js/web/script/test-runner-cli-args.ts          |  4 ---
 .../core/providers/webnn/builders/model.cc     | 34 +++++++------------
 .../providers/webnn/builders/model_builder.cc  | 11 ------
 onnxruntime/wasm/js_internal_api.js            |  2 --
 8 files changed, 15 insertions(+), 51 deletions(-)

diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 729556756cb71..546d50c1ca2d3 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -268,7 +268,7 @@ else()
   endif()
 
   if (onnxruntime_USE_WEBNN)
-    set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT -sASYNCIFY=1 -sASYNCIFY_STACK_SIZE=65536")
+    set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT")
     if (onnxruntime_DISABLE_RTTI)
       set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0")
     endif()
diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts
index b3868871a4753..2c9cd88a375bd 100644
--- a/js/web/lib/build-def.d.ts
+++ b/js/web/lib/build-def.d.ts
@@ -21,10 +21,6 @@ interface BuildDefinitions {
   /**
    * defines whether to disable the whole WebNN backend in the build.
    */
-  readonly DISABLE_WEBNN: boolean;
-  /**
-   * defines whether to disable the whole WebAssembly backend in the build.
-   */
   readonly DISABLE_WASM: boolean;
   /**
    * defines whether to disable proxy feature in WebAssembly backend in the build.
diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts
index baf45e74addea..b212c0f49df3b 100644
--- a/js/web/lib/index.ts
+++ b/js/web/lib/index.ts
@@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
       require('./backend-wasm-training').wasmBackend;
   if (!BUILD_DEFS.DISABLE_WEBGPU) {
     registerBackend('webgpu', wasmBackend, 5);
+    registerBackend('webnn', wasmBackend, 5);
   }
   registerBackend('cpu', wasmBackend, 10);
   registerBackend('wasm', wasmBackend, 10);
-  if (!BUILD_DEFS.DISABLE_WEBNN) {
-    registerBackend('webnn', wasmBackend, 9);
-  }
 }
 
 Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
diff --git a/js/web/script/build.ts b/js/web/script/build.ts
index 374d6a1c22194..d3652f3820357 100644
--- a/js/web/script/build.ts
+++ b/js/web/script/build.ts
@@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/
 const DEFAULT_DEFINE = {
   'BUILD_DEFS.DISABLE_WEBGL': 'false',
   'BUILD_DEFS.DISABLE_WEBGPU': 'false',
-  'BUILD_DEFS.DISABLE_WEBNN': 'false',
   'BUILD_DEFS.DISABLE_WASM': 'false',
   'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
   'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
@@ -364,7 +363,6 @@ async function main() {
       ...DEFAULT_DEFINE,
       'BUILD_DEFS.DISABLE_WEBGPU': 'true',
       'BUILD_DEFS.DISABLE_WEBGL': 'true',
-      'BUILD_DEFS.DISABLE_WEBNN': 'true',
       'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
       'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
     },
@@ -411,7 +409,6 @@ async function main() {
       ...DEFAULT_DEFINE,
       'BUILD_DEFS.DISABLE_WEBGPU': 'true',
       'BUILD_DEFS.DISABLE_WASM': 'true',
-      'BUILD_DEFS.DISABLE_WEBNN': 'true',
     },
   });
   // ort.wasm-core[.min].js
@@ -421,7 +418,6 @@ async function main() {
       ...DEFAULT_DEFINE,
       'BUILD_DEFS.DISABLE_WEBGPU': 'true',
       'BUILD_DEFS.DISABLE_WEBGL': 'true',
-      'BUILD_DEFS.DISABLE_WEBNN': 'true',
       'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
       'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
     },
@@ -434,7 +430,6 @@ async function main() {
       'BUILD_DEFS.DISABLE_TRAINING': 'false',
       'BUILD_DEFS.DISABLE_WEBGPU': 'true',
       'BUILD_DEFS.DISABLE_WEBGL': 'true',
-      'BUILD_DEFS.DISABLE_WEBNN': 'true',
     },
   });
 }
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index 8f6c5f6f04122..ed4dd76a6e315 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
 
   const globalEnvFlags = parseGlobalEnvFlags(args);
 
-  if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
-    throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
-  }
-
   // Options:
   // --log-verbose=<...>
   // --log-info=<...>
diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc
index eaf549ef4e072..48487b9e25240 100644
--- a/onnxruntime/core/providers/webnn/builders/model.cc
+++ b/onnxruntime/core/providers/webnn/builders/model.cc
@@ -26,6 +26,11 @@ Model::~Model() {}
 
 Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                       const InlinedHashMap<std::string, OnnxTensorData>& outputs) {
+  // Allocate the MLNamedArrayBufferViews for inputs and outputs at every compute()
+  // because they will be transferred after compute() done.
+  // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
+  AllocateInputOutputBuffers();
+
   for (const auto& input : inputs) {
     const std::string& name = input.first;
     const struct OnnxTensorData tensor = input.second;
@@ -70,22 +75,13 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                                "The input of graph has unsupported type, name: ",
                                name, " type: ", tensor.tensor_info.data_type);
     }
-#ifdef ENABLE_WEBASSEMBLY_THREADS
-    // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers.
+    // Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer.
+    // As Wasm ArrayBuffer is not detachable.
     wnn_inputs_[name].call<void>("set", view);
-#else
-    wnn_inputs_.set(name, view);
-#endif
   }
 
-#ifdef ENABLE_WEBASSEMBLY_THREADS
-  // This vector uses for recording output buffers from WebNN graph compution when WebAssembly
-  // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView,
-  // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
-  // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory
-  // address is different from the non-shared one, additional memory copy is required here.
   InlinedHashMap<std::string, emscripten::val> output_views;
-#endif
+
   for (const auto& output : outputs) {
     const std::string& name = output.first;
     const struct OnnxTensorData tensor = output.second;
@@ -131,21 +127,17 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                                name, " type: ", tensor.tensor_info.data_type);
     }
-#ifdef ENABLE_WEBASSEMBLY_THREADS
     output_views.insert({name, view});
-#else
-    wnn_outputs_.set(name, view);
-#endif
   }
 
-  wnn_context_.call<emscripten::val>("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_);
-#ifdef ENABLE_WEBASSEMBLY_THREADS
-  // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer.
+  emscripten::val results = wnn_context_.call<emscripten::val>("compute", wnn_graph_, wnn_inputs_, wnn_outputs_).await();
+
+  // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
   for (const auto& output : outputs) {
     const std::string& name = output.first;
     emscripten::val view = output_views.at(name);
-    view.call<void>("set", wnn_outputs_[name]);
+    view.call<void>("set", results["outputs"][name]);
   }
-#endif
+
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index c52a9d26d32f5..4abd244ba4648 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -387,11 +387,7 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
     named_operands.set(name, wnn_operands_.at(name));
   }
 
-  emscripten::val console = emscripten::val::global("console");
-  console.call<void>("log", emscripten::val("start webnn async build()..."));
   emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("build", named_operands).await();
-  console.call<void>("log", wnn_builder_);
-  console.call<void>("log", emscripten::val("Done webnn async build()..."));
   if (!wnn_graph.as<bool>()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
   }
@@ -400,13 +396,6 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   model->SetOutputs(std::move(output_names_));
   model->SetScalarOutputs(std::move(scalar_outputs_));
   model->SetInputOutputInfo(std::move(input_output_info_));
-#ifdef ENABLE_WEBASSEMBLY_THREADS
-  // Pre-allocate the input and output tensors for the WebNN graph
-  // when WebAssembly multi-threads is enabled since WebNN API only
-  // accepts non-shared ArrayBufferView.
-  // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
-  model->AllocateInputOutputBuffers();
-#endif
   return Status::OK();
 }
diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js
index e15ab0872f2eb..a125391c492c4 100644
--- a/onnxruntime/wasm/js_internal_api.js
+++ b/onnxruntime/wasm/js_internal_api.js
@@ -93,7 +93,6 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
   //
   const jsepWrapAsync = (func, getFunc, setFunc) => {
     return (...args) => {
-      console.log('log from jsepWrapAsync start');
       // cache the async data before calling the function.
       const previousAsync = Asyncify.currData;
 
@@ -116,7 +115,6 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
         // returns the promise
         return Asyncify.whenDone();
       }
-      console.log('log from jsepWrapAsync end');
       // the function is synchronous. returns the result.
       return ret;
     };

From 1a7edbbe0c781df0a0aa688dcd0134bc3d67cda7 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Fri, 19 Jan 2024 09:16:47 +0800
Subject: [PATCH 6/9] Remove rest debug log

---
 onnxruntime/core/providers/webnn/webnn_execution_provider.cc | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index be5bd57178c32..1eb98aa937876 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -43,10 +43,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
     context_options.set("powerPreference", emscripten::val(webnn_power_flags));
   }
   emscripten::val console = emscripten::val::global("console");
-  console.call<void>("log", emscripten::val("start webnn async createContext()..."));
   wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
-  console.call<void>("log", wnn_context_);
-  console.call<void>("log", emscripten::val("Done webnn async createContext()..."));
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }

From 76a55ab84914ab70ee96a7c9613d22ca96fdf96c Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Fri, 19 Jan 2024 09:18:35 +0800
Subject: [PATCH 7/9] Remove rest debug log

---
 onnxruntime/core/providers/webnn/webnn_execution_provider.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index 1eb98aa937876..df7871614b267 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -42,7 +42,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
   if (webnn_power_flags.compare("default") != 0) {
     context_options.set("powerPreference", emscripten::val(webnn_power_flags));
   }
-  emscripten::val console = emscripten::val::global("console");
+
   wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }

From 47de0706fd693e066132b7059e6b0f99a7e7e36a Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Mon, 22 Jan 2024 11:19:16 +0800
Subject: [PATCH 8/9] Addressed Ningxin's comment

---
 onnxruntime/core/providers/webnn/builders/model.cc     | 12 ++++++------
 .../core/providers/webnn/builders/model_builder.cc     |  4 ++++
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc
index 48487b9e25240..87a35fc14cb35 100644
--- a/onnxruntime/core/providers/webnn/builders/model.cc
+++ b/onnxruntime/core/providers/webnn/builders/model.cc
@@ -26,11 +26,6 @@ Model::~Model() {}
 
 Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                       const InlinedHashMap<std::string, OnnxTensorData>& outputs) {
-  // Allocate the MLNamedArrayBufferViews for inputs and outputs at every compute()
-  // because they will be transferred after compute() done.
-  // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
-  AllocateInputOutputBuffers();
-
   for (const auto& input : inputs) {
     const std::string& name = input.first;
     const struct OnnxTensorData tensor = input.second;
@@ -129,7 +124,8 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
     output_views.insert({name, view});
   }
-  emscripten::val results = wnn_context_.call<emscripten::val>("compute", wnn_graph_, wnn_inputs_, wnn_outputs_).await();
+  emscripten::val results = wnn_context_.call<emscripten::val>(
+      "compute", wnn_graph_, wnn_inputs_, wnn_outputs_).await();
 
   // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
@@ -137,6 +133,10 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
     emscripten::val view = output_views.at(name);
     view.call<void>("set", results["outputs"][name]);
   }
+  // WebNN compute() method would return the input and output buffers via the promise
+  // resolution. Reuse the buffers to avoid additional allocation.
+  wnn_inputs_ = results["inputs"];
+  wnn_outputs_ = results["outputs"];
 
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index 4abd244ba4648..56f7ead8ccf5d 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -396,6 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   model->SetOutputs(std::move(output_names_));
   model->SetScalarOutputs(std::move(scalar_outputs_));
   model->SetInputOutputInfo(std::move(input_output_info_));
+  // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
+  // for inputs and outputs because they will be transferred after compute() done.
+  // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
+  model->AllocateInputOutputBuffers();
   return Status::OK();
 }

From 564dbb17243d2b3a9e301607cca654ed2b19a461 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Tue, 23 Jan 2024 09:48:46 +0800
Subject: [PATCH 9/9] Fixed lint error

---
 onnxruntime/core/providers/webnn/builders/model.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc
index 87a35fc14cb35..ef807a8c4fa26 100644
--- a/onnxruntime/core/providers/webnn/builders/model.cc
+++ b/onnxruntime/core/providers/webnn/builders/model.cc
@@ -125,7 +125,8 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
     output_views.insert({name, view});
   }
   emscripten::val results = wnn_context_.call<emscripten::val>(
-      "compute", wnn_graph_, wnn_inputs_, wnn_outputs_).await();
+      "compute", wnn_graph_, wnn_inputs_, wnn_outputs_)
+                                .await();
 
   // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
   for (const auto& output : outputs) {