diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 18a068e0ced8b..c4fe1f1db38af 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -17,7 +17,16 @@ class OnnxjsBackend implements Backend { // onnxruntime-common). // In future we should remove Session.Config and use InferenceSession.SessionOptions. // Currently we allow this to happen to make test runner work. - const session = new Session(options as unknown as Session.Config); + const onnxjsOptions = {...options as unknown as Session.Config}; + if (!onnxjsOptions.backendHint && options?.executionProviders && options?.executionProviders[0]) { + const ep = options?.executionProviders[0]; + if (typeof ep === 'string') { + onnxjsOptions.backendHint = ep; + } else { + onnxjsOptions.backendHint = ep.name; + } + } + const session = new Session(onnxjsOptions); // typescript cannot merge method override correctly (so far in 4.2.3). need if-else to call the method. if (typeof pathOrBuffer === 'string') { diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index fea2cd17e8318..c77a2a7682d03 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -12,6 +12,7 @@ import {registerBackend} from 'onnxruntime-common'; if (!BUILD_DEFS.DISABLE_WEBGL) { const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend; registerBackend('webgl', onnxjsBackend, -1); + registerBackend('webgpu', onnxjsBackend, 999); // set to 999 as the highest priority } if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = require('./backend-wasm').wasmBackend; diff --git a/js/web/lib/onnxjs/backend.ts b/js/web/lib/onnxjs/backend.ts index a363ec9f21368..5ac77ae2f5fcb 100644 --- a/js/web/lib/onnxjs/backend.ts +++ b/js/web/lib/onnxjs/backend.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {WebGLBackend} from './backends/backend-webgl'; +import {WebGpuBackend} from './backends/backend-webgpu'; import {Graph} from './graph'; import {Operator} from './operators'; import {OpSet} from './opset'; @@ -79,6 +80,7 @@ const backendsCache: Map = new Map(); export const backend: {[name: string]: Backend} = { webgl: new WebGLBackend(), + webgpu: new WebGpuBackend() }; /** diff --git a/js/web/lib/onnxjs/backends/backend-webgpu.ts b/js/web/lib/onnxjs/backends/backend-webgpu.ts new file mode 100644 index 0000000000000..6919571e83b5a --- /dev/null +++ b/js/web/lib/onnxjs/backends/backend-webgpu.ts @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; +import {Backend, SessionHandler} from '../backend'; +import {Logger} from '../instrument'; +import {Session} from '../session'; + +import {WebGpuSessionHandler} from './webgpu/session-handler'; + +export class WebGpuBackend implements Backend { + initialize(): boolean { + try { + // STEP.1 TODO: set up context (one time initialization) + + // STEP.2 TODO: set up flags + + Logger.setWithEnv(env); + + Logger.verbose('WebGpuBackend', 'Initialized successfully.'); + return true; + } catch (e) { + Logger.warning('WebGpuBackend', `Unable to initialize WebGLBackend. ${e}`); + return false; + } + } + createSessionHandler(context: Session.Context): SessionHandler { + return new WebGpuSessionHandler(this, context); + } + dispose(): void { + // TODO: uninitialization + // this.glContext.dispose(); + } +} diff --git a/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts b/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts new file mode 100644 index 0000000000000..491572cac1863 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceHandler} from '../../backend'; + +import {WebGpuSessionHandler} from './session-handler'; + +export class WebGpuInferenceHandler implements InferenceHandler { + constructor(public session: WebGpuSessionHandler) { + // TODO: + } + + dispose(): void {} +} diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts new file mode 100644 index 0000000000000..d05b19d90a6b4 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {OpSet} from '../../opset'; + +export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ + // ['Abs', '', '6+', unaryOps.abs], + // ['Acos', '', '7+', unaryOps.acos], + // ['Add', '', '7+', binaryOps.add], + // ['And', '', '7+', binaryOps.and], + // ['Asin', '', '7+', unaryOps.asin], + // ['Atan', '', '7+', unaryOps.atan], + // // TODO: support new attributes for AveragePool-10 + // ['AveragePool', '', '7+', averagePool, parseAveragePoolAttributes], + // ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes], + // ['Cast', '', '6+', cast, parseCastAttributes], + // ['Ceil', '', '6+', unaryOps.ceil], + // ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes], + // ['Clip', '', '11+', unaryOps.clipV11], + // ['Concat', '', '4+', concat, parseConcatAttributes], + // ['Conv', '', '1+', conv, parseConvAttributes], + // ['Cos', '', '7+', unaryOps.cos], + // ['Div', '', '7+', binaryOps.div], + // ['Dropout', '', '7+', unaryOps.identity], + // ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes], + // ['Equal', '', '7+', binaryOps.equal], + // ['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes], + // ['Exp', '', '6+', unaryOps.exp], + // ['Flatten', '', '1+', flatten, parseFlattenAttributes], + // ['Floor', '', '6+', unaryOps.floor], + // ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes], + // ['Gather', '', '1+', gather, parseGatherAttributes], + // ['Gemm', '', '7-10', gemm, parseGemmAttributesV7], + // ['Gemm', '', '11+', gemm, parseGemmAttributesV11], + // ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes], + // ['GlobalMaxPool', '', '1+', globalMaxPool], + // ['Greater', '', '7+', binaryOps.greater], + // ['Identity', '', '1+', unaryOps.identity], + // ['ImageScaler', '', '1+', imageScaler, parseImageScalerAttributes], + // ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes], + // ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes], + // ['Less', '', '7+', binaryOps.less], + // ['Log', '', '6+', unaryOps.log], + // ['MatMul', '', '1+', matMul, parseMatMulAttributes], + // // TODO: support new attributes for MaxPool-8 and MaxPool-10 + // ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes], + // ['Mul', '', '7+', binaryOps.mul], + // ['Neg', '', '6+', unaryOps.neg], + // ['Not', '', '1+', unaryOps.not], + // ['Or', '', '7+', binaryOps.or], + // ['Pad', '', '2-10', padV2, parsePadAttributesV2], + // ['Pad', '', '11+', padV11, parsePadAttributesV11], + // ['Pow', '', '7+', binaryOps.pow], + // ['PRelu', '', '7+', binaryOps.pRelu], + // ['ReduceLogSum', '', '1+', reduceLogSum, parseReduceAttributes], + // ['ReduceMax', '', '1+', reduceMax, parseReduceAttributes], + // ['ReduceMean', '', '1+', reduceMean, parseReduceAttributes], + // ['ReduceMin', '', '1+', reduceMin, parseReduceAttributes], + // ['ReduceProd', '', '1+', reduceProd, parseReduceAttributes], + // ['ReduceSum', '', '1-12', reduceSum, parseReduceAttributes], + // ['ReduceSumSquare', '', '1+', reduceLogSumSquare, parseReduceAttributes], + // ['Relu', '', '6+', unaryOps.relu], + // ['Reshape', '', '5+', reshape], + // ['Resize', '', '10', resize, parseResizeAttributesV10], + // ['Resize', '', '11+', resize, parseResizeAttributesV11], + // ['Shape', '', '1+', shape], + // ['Sigmoid', '', '6+', unaryOps.sigmoid], + // ['Sin', '', '7+', unaryOps.sin], + // ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10 + // ['Slice', '', '1-9', slice, parseSliceAttributes], + // // The "semantic" meaning of axis has changed in opset-13. + // ['Softmax', '', '1-12', softmax, parseSoftmaxAttributes], + // ['Softmax', '', '13+', softmaxV13, parseSoftmaxAttributesV13], + // // 'Split' operator has an optional attribute 'split' + // // this attribute determines how the specified axis of input data is split. + // // When the attribute is missing, we need the count of number of outputs + // // so that we can determine the 'split' attribute from the runtime input to the Operator + // ['Split', '', '2-12', split, parseSplitAttributes], + // ['Sqrt', '', '6+', unaryOps.sqrt], + // ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes], + // ['Squeeze', '', '13+', squeezeV13], + // ['Sub', '', '7+', binaryOps.sub], + // ['Sum', '', '6+', sum], + // ['Tan', '', '7+', unaryOps.tan], + // ['Tanh', '', '6+', unaryOps.tanh], + // ['Tile', '', '6+', tile], + // ['Transpose', '', '1+', transpose, parseTransposeAttributes], + // ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7], + // ['Upsample', '', '9', upsample, parseUpsampleAttributesV9], + // ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], + // ['Unsqueeze', '', '13+', unsqueezeV13], + // ['Xor', '', '7+', binaryOps.xor], +]; diff --git a/js/web/lib/onnxjs/backends/webgpu/session-handler.ts b/js/web/lib/onnxjs/backends/webgpu/session-handler.ts new file mode 100644 index 0000000000000..d65d67ab61a57 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/session-handler.ts @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {SessionHandler} from '../../backend'; +import {Graph} from '../../graph'; +import {Operator} from '../../operators'; +import {OpSet, resolveOperator} from '../../opset'; +import {Session} from '../../session'; +import {Tensor} from '../../tensor'; +import {WebGpuBackend} from '../backend-webgpu'; +import {WebGpuInferenceHandler} from './inference-handler'; + +import {WEBGPU_OP_RESOLVE_RULES} from './op-resolve-rules'; + +export class WebGpuSessionHandler implements SessionHandler { + private initializers: Set; + + constructor(public readonly backend: WebGpuBackend, public readonly context: Session.Context) { + // TODO + } + + createInferenceHandler() { + return new WebGpuInferenceHandler(this); + } + onGraphInitialized(graph: Graph): void { + const initializers = graph.getValues().filter(v => v.from === -1 && v.tensor).map(v => v.tensor!.dataId); + this.initializers = new Set(initializers); + } + isInitializer(tensorId: Tensor.Id): boolean { + return this.initializers ? this.initializers.has(tensorId) : false; + } + addInitializer(tensorId: Tensor.Id): void { + this.initializers.add(tensorId); + } + dispose(): void { + // TODO + } + resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator { + const op = resolveOperator(node, opsets, WEBGPU_OP_RESOLVE_RULES); + return {impl: op.opImpl, context: op.opInit ? op.opInit(node, graph) : node}; + } +} diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index c390b330e8252..4c8c098155898 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -34,6 +34,7 @@ Options: -b=<...>, --backend=<...> Specify one or more backend(s) to run the test upon. Backends can be one or more of the following, splitted by comma: webgl + webgpu wasm -e=<...>, --env=<...> Specify the environment to run the test. Should be one of the following: chrome (default) @@ -97,7 +98,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime'; + type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; } @@ -333,7 +334,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'wasm']; + const browserBackends = ['webgl', 'webgpu', 'wasm']; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend =