From 99c51cafe9790b8261884487c280de1d6e85c859 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 19 Mar 2024 12:55:00 -0700 Subject: [PATCH] [js/webgpu] allow setting env.webgpu.adapter (#19940) ### Description Allow user to set `env.webgpu.adapter` before creating the first inference session. Feature request: https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753 @xenova --- js/common/lib/env.ts | 10 +++++--- js/web/lib/wasm/jsep/backend-webgpu.ts | 6 +++-- js/web/lib/wasm/wasm-core-impl.ts | 35 ++++++++++++++++++-------- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index b139c719e863f..c8df1613b3268 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -166,16 +166,20 @@ export declare namespace Env { */ forceFallbackAdapter?: boolean; /** - * Get the adapter for WebGPU. + * Set or get the adapter for WebGPU. * - * This property is only available after the first WebGPU inference session is created. + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as the GPU adapter for the underlying WebGPU backend to create GPU device. + * + * If this property is not set, it will be available to get after the first WebGPU inference session is created. The + * value will be the GPU adapter that created by the underlying WebGPU backend. * * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. * * see comments on {@link Tensor.GpuBufferType} */ - readonly adapter: unknown; + adapter: unknown; /** * Get the device for WebGPU. * diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index d92b8ac68dbe7..b36dc73330d46 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -252,8 +252,10 @@ export class WebGpuBackend { } }; - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter}); + Object.defineProperty( + this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false}); + Object.defineProperty( + this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false}); // init queryType, which is necessary for InferenceSession.create this.setQueryType(); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7019758be0efd..9b27051f1b9fe 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -93,18 +93,31 @@ export const initEp = async(env: Env, epName: string): Promise<void> => { if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); } - const powerPreference = env.webgpu?.powerPreference; - if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { - throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); - } - const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter; - if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { - throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); - } - const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + + let adapter = env.webgpu.adapter as GPUAdapter | null; if (!adapter) { - throw new Error( - 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + // if adapter is not set, request a new adapter. + const powerPreference = env.webgpu.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && + powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. ' + + 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } + } else { + // if adapter is set, validate it. + if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' || + typeof adapter.requestDevice !== 'function') { + throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.'); + } } if (!env.wasm.simd) {