Skip to content

Commit

Permalink
[WebNN EP] Support WebNN async API with Asyncify (microsoft#19145)
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry authored Jan 24, 2024
1 parent 85c91de commit d94e2a9
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 20 deletions.
4 changes: 0 additions & 4 deletions web/lib/build-def.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ interface BuildDefinitions {
/**
* defines whether to disable the whole WebNN backend in the build.
*/
readonly DISABLE_WEBNN: boolean;
/**
* defines whether to disable the whole WebAssembly backend in the build.
*/
readonly DISABLE_WASM: boolean;
/**
* defines whether to disable proxy feature in WebAssembly backend in the build.
Expand Down
4 changes: 1 addition & 3 deletions web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
if (!BUILD_DEFS.DISABLE_WEBNN) {
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
2 changes: 1 addition & 1 deletion web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {

_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;

_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
_OrtReleaseSession(sessionHandle: number): void;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;
Expand Down
4 changes: 2 additions & 2 deletions web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
Expand Down Expand Up @@ -228,7 +228,7 @@ export const createSession = async(
await Promise.all(loadingPromises);
}

sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
Expand Down
7 changes: 1 addition & 6 deletions web/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // <ORT_ROOT>/js/
const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WEBGL': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'false',
'BUILD_DEFS.DISABLE_WEBNN': 'false',
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
Expand Down Expand Up @@ -364,7 +363,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
Expand Down Expand Up @@ -397,7 +395,7 @@ async function main() {
// ort.webgpu[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.webgpu',
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
});
// ort.wasm[.min].js
await addAllWebBuildTasks({
Expand All @@ -411,7 +409,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WASM': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
// ort.wasm-core[.min].js
Expand All @@ -421,7 +418,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
Expand All @@ -434,7 +430,6 @@ async function main() {
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
}
Expand Down
4 changes: 0 additions & 4 deletions web/script/test-runner-cli-args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs

const globalEnvFlags = parseGlobalEnvFlags(args);

if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}

// Options:
// --log-verbose=<...>
// --log-info=<...>
Expand Down

0 comments on commit d94e2a9

Please sign in to comment.