Skip to content

Commit

Permalink
[js/web] revise backend registration (#18715)
Browse files Browse the repository at this point in the history
### Description
This PR revises the backend registration.

The following describes the expected behavior after this change:
(**bolded are changed behavior**)

- (ort.min.js - built without webgpu support)
    - loading: do not register 'webgpu' backend
- creating session without EP list: use default EP list ['webnn', 'cpu',
'wasm']
- creating session with ['webgpu'] as EP list: should fail with backend
not available
- (ort.webgpu.min.js - built with webgpu support)
    - loading: **always register 'webgpu' backend**
( previous behavior: only register 'webgpu' backend when `navigator.gpu`
is available)
- creating session without EP list: use default EP list ['webgpu',
'webnn', 'cpu', 'wasm']
        - when WebGPU is available (win): use WebGPU backend
- when WebGPU is unavailable (android): **should fail backend init,**
and try to use next backend in the list, 'webnn'
(previous behavior: does not fail backend init, but fail in JSEP init,
which was too late to switch to next backend)
    - creating session with ['webgpu'] as EP list
        - when WebGPU is available (win): use WebGPU backend
- when WebGPU is unavailable (android): **should fail backend init, and
because no more EP listed, fail.


related PRs: #18190 #18144
  • Loading branch information
fs-eire authored Dec 20, 2023
1 parent c0142c9 commit 9a61388
Show file tree
Hide file tree
Showing 14 changed files with 393 additions and 390 deletions.
2 changes: 1 addition & 1 deletion js/common/lib/backend-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export const resolveBackend = async(backendHints: readonly string[]): Promise<Ba
const isInitializing = !!backendInfo.initPromise;
try {
if (!isInitializing) {
backendInfo.initPromise = backendInfo.backend.init();
backendInfo.initPromise = backendInfo.backend.init(backendName);
}
await backendInfo.initPromise;
backendInfo.initialized = true;
Expand Down
2 changes: 1 addition & 1 deletion js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export interface Backend {
/**
* Initialize the backend asynchronously. Should throw when failed.
*/
init(): Promise<void>;
init(backendName: string): Promise<void>;

createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler>;
Expand Down
17 changes: 14 additions & 3 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {cpus} from 'node:os';
import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';

/**
Expand Down Expand Up @@ -33,12 +33,23 @@ export const initializeFlags = (): void => {
};

export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
/**
* This function initializes the WebAssembly backend.
*
* This function will be called only once for each backend name. It will be called the first time when
* `ort.InferenceSession.create()` is called with a registered backend name.
*
* @param backendName - the registered backend name.
*/
async init(backendName: string): Promise<void> {
// populate wasm flags
initializeFlags();

// init wasm
await initializeWebAssemblyInstance();
await initializeWebAssemblyAndOrtRuntime();

// performe EP specific initialization
await initializeOrtEp(backendName);
}
createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler>;
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
Expand Down
12 changes: 1 addition & 11 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,7 @@ export class WebGpuBackend {
*/
sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();

async initialize(env: Env): Promise<void> {
if (!navigator.gpu) {
// WebGPU is not available.
throw new Error('WebGpuBackend: WebGPU is not available.');
}

const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error('WebGpuBackend: Failed to get GPU adapter.');
}

async initialize(env: Env, adapter: GPUAdapter): Promise<void> {
this.env = env;
const requiredFeatures: GPUFeatureName[] = [];
const deviceDescriptor: GPUDeviceDescriptor = {
Expand Down
130 changes: 71 additions & 59 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,64 +130,76 @@ class ComputeContextImpl implements ComputeContext {
}
}

export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
const init = module.jsepInit;
if (init && navigator.gpu) {
if (!env.wasm.simd) {
throw new Error(
'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using WebGPU EP');
}
const backend = new WebGpuBackend();
await backend.initialize(env);

init(
// backend
backend,

// jsepAlloc()
(size: number) => backend.alloc(size),

// jsepFree()
(ptr: number) => backend.free(ptr),

// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src, src + size);
backend.upload(dst, data);
}
},

// jsepCopyAsync(src, dst, size)
async(gpuDataId: number, dataOffset: number, size: number):
Promise<void> => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);

await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
},

// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),

// jsepRun
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
/**
* Initialize JSEP with WebGPU backend.
*
* This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
* This function expects:
* - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
* - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
* If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
* 'webgpu' backend.
*
* @param module - the ORT WebAssembly module
* @param env - the ORT environment variable (ort.env)
* @param gpuAdapter - the pre-created GPU adapter
*/
export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
const jsepInit = module.jsepInit;
if (!jsepInit) {
throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
}

const backend = new WebGpuBackend();
await backend.initialize(env, gpuAdapter);

jsepInit(
// backend
backend,

// jsepAlloc()
(size: number) => backend.alloc(size),

// jsepFree()
(ptr: number) => backend.free(ptr),

// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src, src + size);
backend.upload(dst, data);
}
},

// jsepCopyAsync(src, dst, size)
async(gpuDataId: number, dataOffset: number, size: number):
Promise<void> => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);

await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
},

// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),

// jsepRun
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
};
53 changes: 31 additions & 22 deletions js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import type {Env, InferenceSession, Tensor} from 'onnxruntime-common';

/**
* Among all the tensor locations, only 'cpu' is serializable.
*/
export type SerializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu'];

Expand All @@ -12,51 +15,61 @@ export type GpuBufferMetadata = {
dispose?: () => void;
};

/**
* Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable.
*/
export type UnserializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];

/**
* Tensor metadata is a tuple of [dataType, dims, data, location], where
* - dataType: tensor data type
* - dims: tensor dimensions
* - data: tensor data, which can be one of the following depending on the location:
* - cpu: Uint8Array
* - cpu-pinned: Uint8Array
* - gpu-buffer: GpuBufferMetadata
* - location: tensor data location
*/
export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;

export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]];

export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number];
export type SerializableInternalBuffer = [bufferOffset: number, bufferLength: number];

interface MessageError {
err?: string;
}

interface MessageInitWasm extends MessageError {
type: 'init-wasm';
in ?: Env.WebAssemblyFlags;
}

interface MessageInitOrt extends MessageError {
type: 'init-ort';
in ?: Env;
out?: never;
}

interface MessageCreateSessionAllocate extends MessageError {
type: 'create_allocate';
in ?: {model: Uint8Array};
out?: SerializableModeldata;
interface MessageInitEp extends MessageError {
type: 'init-ep';
in ?: {env: Env; epName: string};
out?: never;
}

interface MessageCreateSessionFinalize extends MessageError {
type: 'create_finalize';
in ?: {modeldata: SerializableModeldata; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
interface MessageCopyFromExternalBuffer extends MessageError {
type: 'copy-from';
in ?: {buffer: Uint8Array};
out?: SerializableInternalBuffer;
}

interface MessageCreateSession extends MessageError {
type: 'create';
in ?: {model: Uint8Array; options?: InferenceSession.SessionOptions};
in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
}

interface MessageReleaseSession extends MessageError {
type: 'release';
in ?: number;
out?: never;
}

interface MessageRun extends MessageError {
Expand All @@ -71,12 +84,8 @@ interface MessageRun extends MessageError {
interface MesssageEndProfiling extends MessageError {
type: 'end-profiling';
in ?: number;
out?: never;
}

interface MessageIsOrtEnvInitialized extends MessageError {
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession|
MessageReleaseSession|MessageRun|MesssageEndProfiling;
Loading

0 comments on commit 9a61388

Please sign in to comment.