[WebNN EP] Support WebNN async API with Asyncify #19145

Merged (9 commits) on Jan 24, 2024
Changes from all commits
4 changes: 0 additions & 4 deletions js/web/lib/build-def.d.ts
@@ -21,10 +21,6 @@ interface BuildDefinitions {
/**
* defines whether to disable the whole WebNN backend in the build.
*/
readonly DISABLE_WEBNN: boolean;
/**
* defines whether to disable the whole WebAssembly backend in the build.
*/
readonly DISABLE_WASM: boolean;
/**
* defines whether to disable proxy feature in WebAssembly backend in the build.
4 changes: 1 addition & 3 deletions js/web/lib/index.ts
@@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
if (!BUILD_DEFS.DISABLE_WEBNN) {
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
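
With this change the 'webnn' backend is registered by any bundle that keeps WebGPU enabled, so selecting it is a plain executionProviders entry. A minimal TypeScript sketch (not part of this diff; model URL and fallback choice are illustrative):

import * as ort from 'onnxruntime-web';

// Prefer WebNN, fall back to plain Wasm when the browser lacks WebNN support.
async function createWebNNSession(modelUrl: string): Promise<ort.InferenceSession> {
  return ort.InferenceSession.create(modelUrl, {
    executionProviders: ['webnn', 'wasm'],
  });
}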
2 changes: 1 addition & 1 deletion js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {

_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;

_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
_OrtReleaseSession(sessionHandle: number): void;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;
4 changes: 2 additions & 2 deletions js/web/lib/wasm/wasm-core-impl.ts
@@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
@@ -228,7 +228,7 @@ export const createSession = async(
await Promise.all(loadingPromises);
}

sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
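
The pattern the two changes above rely on, as a rough TypeScript sketch: with Asyncify the exported _OrtCreateSession can suspend, so its binding is declared as returning Promise<number> and must be awaited before the usual zero-handle error check. This assumes the OrtWasmModule type and checkLastError helper already present in these files; the wrapper function itself is illustrative.

async function createSessionHandle(
    wasm: OrtWasmModule, modelDataOffset: number, modelDataLength: number,
    sessionOptionsHandle: number): Promise<number> {
  // Awaiting works for both the suspended (Promise) and fast (already-resolved) paths.
  const sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
  if (sessionHandle === 0) {
    checkLastError('Can\'t create a session.');  // existing helper in wasm-core-impl.ts
  }
  return sessionHandle;
}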
7 changes: 1 addition & 6 deletions js/web/script/build.ts
@@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // <ORT_ROOT>/js/
const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WEBGL': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'false',
'BUILD_DEFS.DISABLE_WEBNN': 'false',
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
@@ -364,7 +363,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
@@ -397,7 +395,7 @@ async function main() {
// ort.webgpu[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.webgpu',
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
});
// ort.wasm[.min].js
await addAllWebBuildTasks({
@@ -411,7 +409,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WASM': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
// ort.wasm-core[.min].js
@@ -421,7 +418,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
@@ -434,7 +430,6 @@ async function main() {
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
}
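
For context, these BUILD_DEFS entries are compile-time defines folded away at bundle time. A hedged esbuild sketch (the actual options used by build.ts are not shown in this diff; paths and flag values here are illustrative): branches guarded by a flag defined as 'true' are eliminated as dead code, so with DISABLE_WEBNN removed the WebNN registration simply ships in every bundle that keeps WebGPU enabled.

import * as esbuild from 'esbuild';

await esbuild.build({
  entryPoints: ['lib/index.ts'],           // illustrative path
  bundle: true,
  define: {
    'BUILD_DEFS.DISABLE_WEBGL': 'true',
    'BUILD_DEFS.DISABLE_WEBGPU': 'false',  // keeping WebGPU now keeps WebNN too
  },
  outfile: 'dist/ort.webgpu.min.js',       // illustrative output
});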
4 changes: 0 additions & 4 deletions js/web/script/test-runner-cli-args.ts
@@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs

const globalEnvFlags = parseGlobalEnvFlags(args);

if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}

// Options:
// --log-verbose=<...>
// --log-info=<...>
35 changes: 14 additions & 21 deletions onnxruntime/core/providers/webnn/builders/model.cc
@@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
"The input of graph has unsupported type, name: ",
name, " type: ", tensor.tensor_info.data_type);
}
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers.
// Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer.
// As Wasm ArrayBuffer is not detachable.
wnn_inputs_[name].call<void>("set", view);
#else
wnn_inputs_.set(name, view);
#endif
}

#ifdef ENABLE_WEBASSEMBLY_THREADS
// This vector uses for recording output buffers from WebNN graph compution when WebAssembly
// multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView,
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
// and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory
// address is different from the non-shared one, additional memory copy is required here.
InlinedHashMap<std::string, emscripten::val> output_views;
#endif

for (const auto& output : outputs) {
const std::string& name = output.first;
const struct OnnxTensorData tensor = output.second;
@@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
name, " type: ", tensor.tensor_info.data_type);
}

#ifdef ENABLE_WEBASSEMBLY_THREADS
output_views.insert({name, view});
#else
wnn_outputs_.set(name, view);
#endif
}
wnn_context_.call<emscripten::val>("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_);
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer.
emscripten::val results = wnn_context_.call<emscripten::val>(
"compute", wnn_graph_, wnn_inputs_, wnn_outputs_)
.await();

// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
for (const auto& output : outputs) {
const std::string& name = output.first;
emscripten::val view = output_views.at(name);
view.call<void>("set", wnn_outputs_[name]);
view.call<void>("set", results["outputs"][name]);
}
#endif
// WebNN compute() method would return the input and output buffers via the promise
// resolution. Reuse the buffers to avoid additional allocation.
wnn_inputs_ = results["inputs"];
wnn_outputs_ = results["outputs"];

return Status::OK();
}

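
The JS-side contract the new Predict() code is built around, sketched in TypeScript from the WebNN async execution model this PR targets: compute() transfers (detaches) the ArrayBuffers passed in and resolves with equivalent buffers, which is why the code above rebinds wnn_inputs_ and wnn_outputs_ from the resolved result. The MLContext/MLGraph names are assumed to come from ambient WebNN type declarations, and the typed-array union is illustrative.

type NamedViews = Record<string, Float32Array | Int32Array | Uint8Array>;

// After the await, the views originally passed in are detached; keep and reuse
// the buffers handed back in the result to avoid re-allocating on every run.
async function computeOnce(context: MLContext, graph: MLGraph,
                           inputs: NamedViews, outputs: NamedViews) {
  const result = await context.compute(graph, inputs, outputs);
  return {inputs: result.inputs as NamedViews, outputs: result.outputs as NamedViews};
}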
12 changes: 5 additions & 7 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
for (auto& name : output_names_) {
named_operands.set(name, wnn_operands_.at(name));
}
emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("buildSync", named_operands);

emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("build", named_operands).await();
if (!wnn_graph.as<bool>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
}
@@ -395,13 +396,10 @@
model->SetOutputs(std::move(output_names_));
model->SetScalarOutputs(std::move(scalar_outputs_));
model->SetInputOutputInfo(std::move(input_output_info_));
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Pre-allocate the input and output tensors for the WebNN graph
// when WebAssembly multi-threads is enabled since WebNN API only
// accepts non-shared ArrayBufferView.
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
// Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
// for inputs and outputs because they will be transferred after compute() done.
// https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
model->AllocateInputOutputBuffers();
#endif
return Status::OK();
}

3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (webnn_power_flags.compare("default") != 0) {
context_options.set("powerPreference", emscripten::val(webnn_power_flags));
}
wnn_context_ = ml.call<emscripten::val>("createContextSync", context_options);

wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
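
For reference, what the two awaited embind calls (createContext above and build in model_builder.cc) correspond to on the JS side, per the WebNN spec draft this PR follows. A hedged TypeScript sketch: option values and descriptor fields are illustrative, and navigator.ml / MLGraphBuilder are assumed to come from ambient WebNN type declarations.

async function buildSmallGraph() {
  // deviceType and powerPreference mirror the provider's context_options above.
  const context = await navigator.ml.createContext({deviceType: 'gpu', powerPreference: 'default'});
  const builder = new MLGraphBuilder(context);
  const x = builder.input('x', {dataType: 'float32', dimensions: [1, 4]});  // descriptor fields approximate
  const y = builder.relu(x);
  // build() is async; the EP blocks on it through Asyncify via .await() in C++.
  const graph = await builder.build({y});
  return {context, graph};
}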
4 changes: 4 additions & 0 deletions onnxruntime/wasm/js_internal_api.js
@@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
};

// replace the original functions with asyncified versions
Module['_OrtCreateSession'] = jsepWrapAsync(
Module['_OrtCreateSession'],
() => Module['_OrtCreateSession'],
v => Module['_OrtCreateSession'] = v);
Module['_OrtRun'] = runAsync(jsepWrapAsync(
Module['_OrtRun'],
() => Module['_OrtRun'],
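
A condensed sketch of the idea behind jsepWrapAsync, written in TypeScript rather than the file's plain JS and ignoring the getter/setter plumbing the real helper uses to track Module re-exports: an Asyncify export returns a Promise when the call suspended and a plain value otherwise, so the wrapper normalizes the result and lets callers such as createSession always await it. The helper name and types below are illustrative.

type WasmExport = (...args: number[]) => number | Promise<number>;

const wrapAsync = (fn: WasmExport) =>
    (...args: number[]): Promise<number> => {
      const ret = fn(...args);
      // Asyncify: a suspended call yields a Promise; a fast path returns the number directly.
      return ret instanceof Promise ? ret : Promise.resolve(ret);
    };

// e.g. (illustrative): Module['_OrtCreateSession'] = wrapAsync(Module['_OrtCreateSession']);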