From d226e40856738531cf8b481b07379545f7cfefe2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 08:08:55 +0800 Subject: [PATCH] [js/webgpu] set query type in onRunStart (#19202) ### Description `env.webgpu.profiling` is a global flag. It may change before each session.run. So the best place is to update it in `onRunStart` event. After this, we can directly check `this.queryType`'s value. Without this pr, we need to make sure that `getCommandEncoder()` is called before checking `this.queryType`. Otherwise, it may happen that `pendingKernels`'s length is not equal to `pendingDispatchNumber`'s length. See the two ugly workarounds [1)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268) and [2)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80) if we don't introduce `onRunStart`. Or we need to call `setQueryType` in each kernel run. --- js/web/lib/wasm/binding/ort-wasm.d.ts | 4 ++++ js/web/lib/wasm/jsep/backend-webgpu.ts | 9 +++++---- js/web/lib/wasm/wasm-core-impl.ts | 2 +- onnxruntime/wasm/js_internal_api.js | 3 +++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 9d4d5875310b7..68054210e79a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. + */ + jsepOnRunStart: () => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2956ec1cad4da..afef7042a4280 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -208,7 +208,7 @@ export class WebGpuBackend { Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - // init queryType, which is necessary for createKernel + // init queryType, which is necessary for InferenceSession.create this.setQueryType(); } @@ -223,8 +223,6 @@ export class WebGpuBackend { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - // refresh queryType, as sometimes we only need to enable query for a specific run - this.setQueryType(); if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { this.querySet = this.device.createQuerySet({ type: 'timestamp', @@ -639,6 +637,7 @@ export class WebGpuBackend { return createView(data.buffer, type); }; } + // #endregion writeTimestamp(index: number): void { if (this.queryType !== 'inside-passes') { return; @@ -657,5 +656,7 @@ export class WebGpuBackend { } } } - // #endregion + onRunStart(): void { + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5821fac3c468f..8768643fa7257 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -488,8 +488,8 @@ export const run = async( } } + wasm.jsepOnRunStart?.(); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 25ece9c700d5d..7c70515e73eab 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnRunStart'] = () => { + return backend['onRunStart'](); + }; };