diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index e129c6971a85c..3e1e833addb91 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -82,7 +82,7 @@ export const resolveBackend = async(backendHints: readonly string[]): Promise; + init(backendName: string): Promise; createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise; diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 78edcc90f55f9..2d123cdb71290 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -4,7 +4,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; -import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; +import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper'; import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** @@ -33,12 +33,23 @@ export const initializeFlags = (): void => { }; export class OnnxruntimeWebAssemblyBackend implements Backend { - async init(): Promise { + /** + * This function initializes the WebAssembly backend. + * + * This function will be called only once for each backend name. It will be called the first time when + * `ort.InferenceSession.create()` is called with a registered backend name. + * + * @param backendName - the registered backend name. 
+ */ + async init(backendName: string): Promise { // populate wasm flags initializeFlags(); // init wasm - await initializeWebAssemblyInstance(); + await initializeWebAssemblyAndOrtRuntime(); + + // perform EP specific initialization + await initializeOrtEp(backendName); } createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 6060271ced156..499327741c82b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -21,7 +21,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : require('./backend-wasm-training').wasmBackend; - if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { + if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 4f4a06c37a94f..6c3d22352772e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -144,17 +144,7 @@ export class WebGpuBackend { */ sessionExternalDataMapping: Map> = new Map(); - async initialize(env: Env): Promise { - if (!navigator.gpu) { - // WebGPU is not available. 
- throw new Error('WebGpuBackend: WebGPU is not available.'); - } - - const adapter = await navigator.gpu.requestAdapter(); - if (!adapter) { - throw new Error('WebGpuBackend: Failed to get GPU adapter.'); - } - + async initialize(env: Env, adapter: GPUAdapter): Promise { this.env = env; const requiredFeatures: GPUFeatureName[] = []; const deviceDescriptor: GPUDeviceDescriptor = { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index e6db631c44eea..cad1e87b24a51 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -130,64 +130,76 @@ class ComputeContextImpl implements ComputeContext { } } -export const init = async(module: OrtWasmModule, env: Env): Promise => { - const init = module.jsepInit; - if (init && navigator.gpu) { - if (!env.wasm.simd) { - throw new Error( - 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using WebGPU EP'); - } - const backend = new WebGpuBackend(); - await backend.initialize(env); - - init( - // backend - backend, - - // jsepAlloc() - (size: number) => backend.alloc(size), - - // jsepFree() - (ptr: number) => backend.free(ptr), - - // jsepCopy(src, dst, size, isSourceGpu) - (src: number, dst: number, size: number, isSourceGpu = false) => { - if (isSourceGpu) { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); - backend.memcpy(src, dst); - } else { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); - const data = module.HEAPU8.subarray(src, src + size); - backend.upload(dst, data); - } - }, - - // jsepCopyAsync(src, dst, size) - async(gpuDataId: number, dataOffset: number, size: number): - Promise => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); - - await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size)); - }, - - // 
jsepCreateKernel - (name: string, kernel: number, attribute: unknown) => backend.createKernel( - name, kernel, attribute, - env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`), - - // jsepReleaseKernel - (kernel: number) => backend.releaseKernel(kernel), - - // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ - contextDataOffset}`); - const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, errors); - }); +/** + * Initialize JSEP with WebGPU backend. + * + * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called). + * This function expects: + * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). + * - WebGPU is available in current environment. (a valid GPUAdapter is passed in) + * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate + * 'webgpu' backend. + * + * @param module - the ORT WebAssembly module + * @param env - the ORT environment variable (ort.env) + * @param gpuAdapter - the pre-created GPU adapter + */ +export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise => { + const jsepInit = module.jsepInit; + if (!jsepInit) { + throw new Error('Failed to initialize JSEP. 
The WebAssembly module is not built with JSEP support.'); } + + const backend = new WebGpuBackend(); + await backend.initialize(env, gpuAdapter); + + jsepInit( + // backend + backend, + + // jsepAlloc() + (size: number) => backend.alloc(size), + + // jsepFree() + (ptr: number) => backend.free(ptr), + + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { + LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); + backend.memcpy(src, dst); + } else { + LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); + const data = module.HEAPU8.subarray(src, src + size); + backend.upload(dst, data); + } + }, + + // jsepCopyAsync(src, dst, size) + async(gpuDataId: number, dataOffset: number, size: number): + Promise => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); + + await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size)); + }, + + // jsepCreateKernel + (name: string, kernel: number, attribute: unknown) => backend.createKernel( + name, kernel, attribute, + env.debug || backend.isQueryEnabled() ? 
module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`), + + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), + + // jsepRun + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ + contextDataOffset}`); + const context = new ComputeContextImpl(module, backend, contextDataOffset); + return backend.computeKernel(kernel, context, errors); + }); }; diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index efeb086256cf3..02246c9ee4767 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -3,6 +3,9 @@ import type {Env, InferenceSession, Tensor} from 'onnxruntime-common'; +/** + * Among all the tensor locations, only 'cpu' is serializable. + */ export type SerializableTensorMetadata = [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; @@ -12,15 +15,28 @@ export type GpuBufferMetadata = { dispose?: () => void; }; +/** + * Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable. 
+ */ export type UnserializableTensorMetadata = [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; +/** + * Tensor metadata is a tuple of [dataType, dims, data, location], where + * - dataType: tensor data type + * - dims: tensor dimensions + * - data: tensor data, which can be one of the following depending on the location: + * - cpu: Uint8Array + * - cpu-pinned: Uint8Array + * - gpu-buffer: GpuBufferMetadata + * - location: tensor data location + */ export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; -export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; +export type SerializableInternalBuffer = [bufferOffset: number, bufferLength: number]; interface MessageError { err?: string; @@ -28,35 +44,32 @@ interface MessageError { interface MessageInitWasm extends MessageError { type: 'init-wasm'; - in ?: Env.WebAssemblyFlags; -} - -interface MessageInitOrt extends MessageError { - type: 'init-ort'; in ?: Env; + out?: never; } -interface MessageCreateSessionAllocate extends MessageError { - type: 'create_allocate'; - in ?: {model: Uint8Array}; - out?: SerializableModeldata; +interface MessageInitEp extends MessageError { + type: 'init-ep'; + in ?: {env: Env; epName: string}; + out?: never; } -interface MessageCreateSessionFinalize extends MessageError { - type: 'create_finalize'; - in ?: {modeldata: SerializableModeldata; options?: InferenceSession.SessionOptions}; - out?: SerializableSessionMetadata; +interface MessageCopyFromExternalBuffer extends MessageError { + type: 'copy-from'; + in ?: {buffer: Uint8Array}; + out?: SerializableInternalBuffer; } interface MessageCreateSession extends MessageError { type: 'create'; - in ?: {model: Uint8Array; 
options?: InferenceSession.SessionOptions}; + in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions}; out?: SerializableSessionMetadata; } interface MessageReleaseSession extends MessageError { type: 'release'; in ?: number; + out?: never; } interface MessageRun extends MessageError { @@ -71,12 +84,8 @@ interface MessageRun extends MessageError { interface MesssageEndProfiling extends MessageError { type: 'end-profiling'; in ?: number; + out?: never; } -interface MessageIsOrtEnvInitialized extends MessageError { - type: 'is-ort-env-initialized'; - out?: boolean; -} - -export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; +export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession| + MessageReleaseSession|MessageRun|MesssageEndProfiling; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 1cb6d9e391e4f..4df524cdcfb22 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -36,104 +36,82 @@ declare global { } import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; +import {createSession, copyFromExternalBuffer, endProfiling, extractTransferableBuffers, initEp, initRuntime, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { - switch (ev.data.type) { - case 'init-wasm': - try { - initializeWebAssembly(ev.data.in!) 
+ const {type, in : message} = ev.data; + try { + switch (type) { + case 'init-wasm': + initializeWebAssembly(message!.wasm) .then( - () => postMessage({type: 'init-wasm'} as OrtWasmMessage), - err => postMessage({type: 'init-wasm', err} as OrtWasmMessage)); - } catch (err) { - postMessage({type: 'init-wasm', err} as OrtWasmMessage); - } - break; - case 'init-ort': - try { - initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ - type: 'init-ort', - err - } as OrtWasmMessage)); - } catch (err) { - postMessage({type: 'init-ort', err} as OrtWasmMessage); - } - break; - case 'create_allocate': - try { - const {model} = ev.data.in!; - const modeldata = createSessionAllocate(model); - postMessage({type: 'create_allocate', out: modeldata} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'create_allocate', err} as OrtWasmMessage); + () => { + initRuntime(message!).then( + () => { + postMessage({type}); + }, + err => { + postMessage({type, err}); + }); + }, + err => { + postMessage({type, err}); + }); + break; + case 'init-ep': { + const {epName, env} = message!; + initEp(env, epName) + .then( + () => { + postMessage({type}); + }, + err => { + postMessage({type, err}); + }); + break; } - break; - case 'create_finalize': - try { - const {modeldata, options} = ev.data.in!; - const sessionMetadata = createSessionFinalize(modeldata, options); - postMessage({type: 'create_finalize', out: sessionMetadata} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'create_finalize', err} as OrtWasmMessage); + case 'copy-from': { + const {buffer} = message!; + const bufferData = copyFromExternalBuffer(buffer); + postMessage({type, out: bufferData} as OrtWasmMessage); + break; } - break; - case 'create': - try { - const {model, options} = ev.data.in!; + case 'create': { + const {model, options} = message!; const sessionMetadata = createSession(model, options); - postMessage({type: 'create', out: sessionMetadata} as 
OrtWasmMessage); - } catch (err) { - postMessage({type: 'create', err} as OrtWasmMessage); + postMessage({type, out: sessionMetadata} as OrtWasmMessage); + break; } - break; - case 'release': - try { - releaseSession(ev.data.in!); - postMessage({type: 'release'} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'release', err} as OrtWasmMessage); - } - break; - case 'run': - try { - const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!; + case 'release': + releaseSession(message!); + postMessage({type}); + break; + case 'run': { + const {sessionId, inputIndices, inputs, outputIndices, options} = message!; run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options) .then( outputs => { if (outputs.some(o => o[3] !== 'cpu')) { - postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'}); + postMessage({type, err: 'Proxy does not support non-cpu tensor location.'}); } else { postMessage( - {type: 'run', out: outputs} as OrtWasmMessage, + {type, out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs as SerializableTensorMetadata[])); } }, err => { - postMessage({type: 'run', err} as OrtWasmMessage); + postMessage({type, err}); }); - } catch (err) { - postMessage({type: 'run', err} as OrtWasmMessage); - } - break; - case 'end-profiling': - try { - const handler = ev.data.in!; - endProfiling(handler); - postMessage({type: 'end-profiling'} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'end-profiling', err} as OrtWasmMessage); - } - break; - case 'is-ort-env-initialized': - try { - const ortEnvInitialized = isOrtEnvInitialized(); - postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + break; } - break; - default: + case 'end-profiling': + endProfiling(message!); + postMessage({type}); + break; + default: + } + } catch (err) 
{ + postMessage({type, err} as OrtWasmMessage); } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 069a1fa452dbc..86017a4ec6904 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env, env, InferenceSession} from 'onnxruntime-common'; +import {env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; +import {OrtWasmMessage, SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -13,18 +13,18 @@ let initializing = false; let initialized = false; let aborted = false; -// resolve; reject -type PromiseCallbacks = [(result: T) => void, (reason: unknown) => void]; - +type PromiseCallbacks = [resolve: (result: T) => void, reject: (reason: unknown) => void]; let initWasmCallbacks: PromiseCallbacks; -let initOrtCallbacks: PromiseCallbacks; -const createSessionAllocateCallbacks: Array> = []; -const createSessionFinalizeCallbacks: Array> = []; -const createSessionCallbacks: Array> = []; -const releaseSessionCallbacks: Array> = []; -const runCallbacks: Array> = []; -const endProfilingCallbacks: Array> = []; -const isOrtEnvInitializedCallbacks: Array> = []; +const queuedCallbacks: Map>> = new Map(); + +const enqueueCallbacks = (type: OrtWasmMessage['type'], callbacks: PromiseCallbacks): void => { + const queue = queuedCallbacks.get(type); + if (queue) { + queue.push(callbacks); + } else { + queuedCallbacks.set(type, [callbacks]); + } +}; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -44,82 +44,40 @@ const onProxyWorkerMessage = 
(ev: MessageEvent): void => { initWasmCallbacks[0](); } break; - case 'init-ort': - if (ev.data.err) { - initOrtCallbacks[1](ev.data.err); - } else { - initOrtCallbacks[0](); - } - break; - case 'create_allocate': - if (ev.data.err) { - createSessionAllocateCallbacks.shift()![1](ev.data.err); - } else { - createSessionAllocateCallbacks.shift()![0](ev.data.out!); - } - break; - case 'create_finalize': - if (ev.data.err) { - createSessionFinalizeCallbacks.shift()![1](ev.data.err); - } else { - createSessionFinalizeCallbacks.shift()![0](ev.data.out!); - } - break; + case 'init-ep': + case 'copy-from': case 'create': - if (ev.data.err) { - createSessionCallbacks.shift()![1](ev.data.err); - } else { - createSessionCallbacks.shift()![0](ev.data.out!); - } - break; case 'release': - if (ev.data.err) { - releaseSessionCallbacks.shift()![1](ev.data.err); - } else { - releaseSessionCallbacks.shift()![0](); - } - break; case 'run': + case 'end-profiling': { + const callbacks = queuedCallbacks.get(ev.data.type)!; if (ev.data.err) { - runCallbacks.shift()![1](ev.data.err); - } else { - runCallbacks.shift()![0](ev.data.out!); - } - break; - case 'end-profiling': - if (ev.data.err) { - endProfilingCallbacks.shift()![1](ev.data.err); - } else { - endProfilingCallbacks.shift()![0](); - } - break; - case 'is-ort-env-initialized': - if (ev.data.err) { - isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + callbacks.shift()![1](ev.data.err); } else { - isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + callbacks.shift()![0](ev.data.out!); } break; + } default: } }; const scriptSrc = typeof document !== 'undefined' ? 
(document?.currentScript as HTMLScriptElement)?.src : undefined; -export const initializeWebAssemblyInstance = async(): Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - if (initialized) { - return; - } - if (initializing) { - throw new Error('multiple calls to \'initWasm()\' detected.'); - } - if (aborted) { - throw new Error('previous call to \'initWasm()\' failed.'); - } +export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { + if (initialized) { + return; + } + if (initializing) { + throw new Error('multiple calls to \'initWasm()\' detected.'); + } + if (aborted) { + throw new Error('previous call to \'initWasm()\' failed.'); + } - initializing = true; + initializing = true; + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { // overwrite wasm filepaths if (env.wasm.wasmPaths === undefined) { if (scriptSrc && scriptSrc.indexOf('blob:') !== 0) { @@ -142,78 +100,78 @@ export const initializeWebAssemblyInstance = async(): Promise => { proxyWorker.onmessage = onProxyWorkerMessage; URL.revokeObjectURL(workerUrl); initWasmCallbacks = [resolve, reject]; - const message: OrtWasmMessage = {type: 'init-wasm', in : env.wasm}; + const message: OrtWasmMessage = {type: 'init-wasm', in : env}; proxyWorker.postMessage(message); }); } else { - return initializeWebAssembly(env.wasm); + try { + await initializeWebAssembly(env.wasm); + await core.initRuntime(env); + initialized = true; + } catch (e) { + aborted = true; + throw e; + } finally { + initializing = false; + } } }; -export const initializeRuntime = async(env: Env): Promise => { +export const initializeOrtEp = async(epName: string): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { - initOrtCallbacks = [resolve, reject]; - const message: OrtWasmMessage = {type: 'init-ort', in : env}; + enqueueCallbacks('init-ep', [resolve, reject]); + const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env}}; 
proxyWorker!.postMessage(message); }); } else { - await core.initRuntime(env); + await core.initEp(env, epName); } }; -export const createSessionAllocate = async(model: Uint8Array): Promise => { +export const copyFromExternalBuffer = async(buffer: Uint8Array): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); - return new Promise((resolve, reject) => { - createSessionAllocateCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'create_allocate', in : {model}}; - proxyWorker!.postMessage(message, [model.buffer]); + return new Promise((resolve, reject) => { + enqueueCallbacks('copy-from', [resolve, reject]); + const message: OrtWasmMessage = {type: 'copy-from', in : {buffer}}; + proxyWorker!.postMessage(message, [buffer.buffer]); }); } else { - return core.createSessionAllocate(model); + return core.copyFromExternalBuffer(buffer); } }; -export const createSessionFinalize = async(modeldata: SerializableModeldata, options?: InferenceSession.SessionOptions): - Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - ensureWorker(); - return new Promise((resolve, reject) => { - createSessionFinalizeCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'create_finalize', in : {modeldata, options}}; - proxyWorker!.postMessage(message); - }); - } else { - return core.createSessionFinalize(modeldata, options); - } - }; - export const createSession = - async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - // check unsupported options - if (options?.preferredOutputLocation) { - throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); - } - ensureWorker(); - return new Promise((resolve, reject) => { - createSessionCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'create', in : {model, options}}; - proxyWorker!.postMessage(message, [model.buffer]); - }); 
- } else { - return core.createSession(model, options); - } -}; + async(model: SerializableInternalBuffer|Uint8Array, options?: InferenceSession.SessionOptions): + Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } + ensureWorker(); + return new Promise((resolve, reject) => { + enqueueCallbacks('create', [resolve, reject]); + const message: OrtWasmMessage = {type: 'create', in : {model, options}}; + const transferable: Transferable[] = []; + if (model instanceof Uint8Array) { + transferable.push(model.buffer); + } + proxyWorker!.postMessage(message, transferable); + }); + } else { + return core.createSession(model, options); + } + }; export const releaseSession = async(sessionId: number): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { - releaseSessionCallbacks.push([resolve, reject]); + enqueueCallbacks('release', [resolve, reject]); const message: OrtWasmMessage = {type: 'release', in : sessionId}; proxyWorker!.postMessage(message); }); @@ -236,7 +194,7 @@ export const run = async( } ensureWorker(); return new Promise((resolve, reject) => { - runCallbacks.push([resolve, reject]); + enqueueCallbacks('run', [resolve, reject]); const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. 
const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; @@ -251,7 +209,7 @@ export const endProfiling = async(sessionId: number): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { - endProfilingCallbacks.push([resolve, reject]); + enqueueCallbacks('end-profiling', [resolve, reject]); const message: OrtWasmMessage = {type: 'end-profiling', in : sessionId}; proxyWorker!.postMessage(message); }); @@ -259,16 +217,3 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; - -export const isOrtEnvInitialized = async(): Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - ensureWorker(); - return new Promise((resolve, reject) => { - isOrtEnvInitializedCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; - proxyWorker!.postMessage(message); - }); - } else { - return core.isOrtEnvInitialized(); - } -}; diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts index 3ca34d957c572..b62287483208a 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -2,14 +2,12 @@ // Licensed under the MIT License. 
import {readFile} from 'node:fs/promises'; -import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; +import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; +import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitializationPromise: Promise|undefined; - export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': @@ -44,7 +42,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan inputNames: string[]; outputNames: string[]; - async createSessionAllocate(path: string): Promise { + async fetchModelAndCopyToWasmMemory(path: string): Promise { // fetch model from url and move to wasm heap. 
The arraybufffer that held the http // response is freed once we return const response = await fetch(path); @@ -52,33 +50,26 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan throw new Error(`failed to load model: ${path}`); } const arrayBuffer = await response.arrayBuffer(); - return createSessionAllocate(new Uint8Array(arrayBuffer)); + return copyFromExternalBuffer(new Uint8Array(arrayBuffer)); } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!(await isOrtEnvInitialized())) { - if (!runtimeInitializationPromise) { - runtimeInitializationPromise = initializeRuntime(env); - } - await runtimeInitializationPromise; - runtimeInitializationPromise = undefined; - } + let model: Parameters[0]; if (typeof pathOrBuffer === 'string') { if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - const model = await readFile(pathOrBuffer); - [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options); + model = await readFile(pathOrBuffer); } else { // browser - // fetch model and move to wasm heap. - const modelData: SerializableModeldata = await this.createSessionAllocate(pathOrBuffer); - // create the session - [this.sessionId, this.inputNames, this.outputNames] = await createSessionFinalize(modelData, options); + // fetch model and copy to wasm heap. 
+ model = await this.fetchModelAndCopyToWasmMemory(pathOrBuffer); } } else { - [this.sessionId, this.inputNames, this.outputNames] = await createSession(pathOrBuffer, options); + model = pathOrBuffer; } + + [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options); } async dispose(): Promise { diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 71815f21e650a..e35759192fe3c 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; -import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; -import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {copyFromExternalBuffer} from './wasm-core-impl'; import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { @@ -18,7 +18,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes evalInputNames: string[] = []; evalOutputNames: string[] = []; - async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { let buffer: 
Uint8Array; if (typeof uriOrBuffer === 'string') { const response = await fetch(uriOrBuffer); @@ -27,21 +27,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } else { buffer = uriOrBuffer; } - return createSessionAllocate(buffer); + return copyFromExternalBuffer(buffer); } async createTrainingSession( checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { - if (!isOrtEnvInitialized()) { - await initRuntime(env); - } - const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); // 0 is supposed to be the nullptr - let evalModelData: SerializableModeldata = [0, 0]; - let optimizerModelData: SerializableModeldata = [0, 0]; + let evalModelData: SerializableInternalBuffer = [0, 0]; + let optimizerModelData: SerializableInternalBuffer = [0, 0]; if (evalModelUriOrBuffer !== '') { evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 3aacf8f4d90e0..a9dfd9218bb6f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,37 +3,60 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; +import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from 
'./run-options'; import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; -let ortEnvInitialized = false; +// #region Initializations /** - * get the input/output count of the session. - * @param sessionHandle the handle representing the session. should be non-zero. - * @returns a tuple including 2 numbers, representing the input count and output count. + * There are 4 different "initialization" steps for ORT. They happen in different places and at different times. + * + * 1. JavaScript initialization for onnxruntime-common and onnxruntime-web. + * This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend() + * function multiple times to register all the available backends. The backend registration is very fast. It only + * registers the backend name with the uninitialized backend object. No heavy initialization is done in this step. + * Refer to web/lib/index.ts for the backend registration. + * + * 2. WebAssembly artifact initialization. + * This happens when any registered wasm backend is used for the first time (i.e. `ort.InferenceSession.create()` or + * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the following: + * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. + * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated + * JavaScript code to initialize the WebAssembly runtime. + * - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'. + * - downloading the 'ort-wasm{...}.wasm' file is done in this step. 
+ * - if multi-thread is enabled, one or more web workers will be created to initialize the PThread threadpool. + * + * 3. ORT environment initialization. + * This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization. + * Function `_OrtInit()` is called in this step. + * - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'. + * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. + * + * 4. Session initialization. + * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 + * steps (they are only called once), this step will be done for each session. In this step, onnxruntime-web does the + * following: + * If the parameter is a URL: + * - download the model data from the URL. + * - copy the model data to the WASM heap. (proxy: 'copy-from') + * - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected. + * - call `_OrtCreateSession()` to create the session. (proxy: 'create') + * + * If the parameter is a Uint8Array object: + * - copy the model data to the WASM heap. (proxy: 'copy-from') + * - call `_OrtCreateSession()` to create the session. (proxy: 'create') + * + * */ -const getSessionInputOutputCount = (sessionHandle: number): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } finally { - wasm.stackRestore(stack); - } -}; /** * initialize ORT environment. 
+ * * @param numThreads SetGlobalIntraOpNumThreads(numThreads) * @param loggingLevel CreateEnv(static_cast(logging_level)) */ @@ -51,18 +74,41 @@ const initOrt = (numThreads: number, loggingLevel: number): void => { export const initRuntime = async(env: Env): Promise => { // init ORT initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel)); +}; + +/** + * perform EP specific initialization. + * + * @param env + * @param epName + */ +export const initEp = async(env: Env, epName: string): Promise => { + if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + // perform WebGPU availability check + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); + } + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } + + if (!env.wasm.simd) { + throw new Error( + 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP'); + } - if (!BUILD_DEFS.DISABLE_WEBGPU) { // init JSEP if available // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; - await initJsep(getInstance(), env); + await initJsep(getInstance(), env, adapter); } - - ortEnvInitialized = true; }; +// #endregion Initializations + /** * valid data locations for input/output tensors. */ @@ -97,13 +143,33 @@ type SessionMetadata = [ const activeSessions = new Map(); -export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; +/** + * get the input/output count of the session. + * @param sessionHandle the handle representing the session. should be non-zero. + * @returns a tuple including 2 numbers, representing the input count and output count. 
+ */ +const getSessionInputOutputCount = (sessionHandle: number): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } finally { + wasm.stackRestore(stack); + } +}; /** - * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. + * allocate the memory and memcpy the external buffer. + * + * @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap. * @returns a 2-elements tuple - the pointer and size of the allocated buffer */ -export const createSessionAllocate = (model: Uint8Array): [number, number] => { +export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => { const wasm = getInstance(); const modelDataOffset = wasm._malloc(model.byteLength); if (modelDataOffset === 0) { @@ -114,15 +180,30 @@ export const createSessionAllocate = (model: Uint8Array): [number, number] => { }; /** - * create an inference session using the prepared buffer containing the model data. - * @param modelData a 2-elements tuple containing the pointer and size of the model data buffer. + * create an inference session from a model data buffer. + * + * @param modelData - either a Uint8Array object representing the model data, or a 2-elements tuple containing the + * pointer and size of the model data buffer. * @param options an optional session options object. 
* @returns a 3-elements tuple containing [session handle, input names, output names] */ -export const createSessionFinalize = - (modelData: SerializableModeldata, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { +export const createSession = + (modelData: Uint8Array|SerializableInternalBuffer, + options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { + let modelDataOffset: number, modelDataLength: number; const wasm = getInstance(); + if (Array.isArray(modelData)) { + // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data + [modelDataOffset, modelDataLength] = modelData; + } else if (modelData.buffer === wasm.HEAPU8.buffer) { + // if model data uses the same buffer as the WASM heap, we don't need to copy it. + [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength]; + } else { + // otherwise, copy the model data to the WASM heap. + [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData); + } + let sessionHandle = 0; let sessionOptionsHandle = 0; let ioBindingHandle = 0; @@ -133,7 +214,7 @@ export const createSessionFinalize = try { [sessionOptionsHandle, allocs] = setSessionOptions(options); - sessionHandle = wasm._OrtCreateSession(modelData[0], modelData[1], sessionOptionsHandle); + sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } @@ -201,7 +282,7 @@ export const createSessionFinalize = } throw e; } finally { - wasm._free(modelData[0]); + wasm._free(modelDataOffset); if (sessionOptionsHandle !== 0) { wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } @@ -209,17 +290,6 @@ export const createSessionFinalize = } }; - -/** - * create an instance of InferenceSession. - * @returns the metadata of InferenceSession. 0-value handle for failure. 
- */ -export const createSession = - (model: Uint8Array, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { - const modelData: SerializableModeldata = createSessionAllocate(model); - return createSessionFinalize(modelData, options); - }; - export const releaseSession = (sessionId: number): void => { const wasm = getInstance(); const session = activeSessions.get(sessionId); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 0cc28188a6093..c65178e2358d2 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,7 +3,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; @@ -32,7 +32,7 @@ const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero } }; -export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { +export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { const wasm = getInstance(); const [checkpointDataOffset, checkpointDataLength] = checkpointData; @@ -108,8 +108,8 @@ export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: }; export const createTrainingSessionHandle = - (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, - optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => { + (checkpointHandle: number, trainModelData: SerializableInternalBuffer, evalModelData: SerializableInternalBuffer, + optimizerModelData: 
SerializableInternalBuffer, options: InferenceSession.SessionOptions): number => { const wasm = getInstance(); let trainingSessionHandle = 0; diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 29acc07e118f9..5e9b0910a2c68 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -850,7 +850,7 @@ export class ProtoOpTestContext { this.backendHint = test.backend!; this.ioBindingMode = test.ioBinding; - this.loadedData = onnx.ModelProto.encode(model).finish(); + this.loadedData = onnx.ModelProto.encode(model).finish().slice(); // in debug mode, open a new tab in browser for the generated onnx model. if (ort.env.debug) {