From 99c51cafe9790b8261884487c280de1d6e85c859 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Tue, 19 Mar 2024 12:55:00 -0700
Subject: [PATCH] [js/webgpu] allow setting env.webgpu.adapter (#19940)

### Description
Allow user to set `env.webgpu.adapter` before creating the first
inference session.

Feature request:
https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753

@xenova
---
 js/common/lib/env.ts                   | 10 +++++---
 js/web/lib/wasm/jsep/backend-webgpu.ts |  6 +++--
 js/web/lib/wasm/wasm-core-impl.ts      | 35 ++++++++++++++++++--------
 3 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index b139c719e863f..c8df1613b3268 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -166,16 +166,20 @@ export declare namespace Env {
      */
     forceFallbackAdapter?: boolean;
     /**
-     * Get the adapter for WebGPU.
+     * Set or get the adapter for WebGPU.
      *
-     * This property is only available after the first WebGPU inference session is created.
+     * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+     * used as the GPU adapter for the underlying WebGPU backend to create GPU device.
+     *
+     * If this property is not set, it will be available to get after the first WebGPU inference session is created. The
+     * value will be the GPU adapter that created by the underlying WebGPU backend.
      *
      * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
      * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
      *
      * see comments on {@link Tensor.GpuBufferType}
      */
-    readonly adapter: unknown;
+    adapter: unknown;
     /**
      * Get the device for WebGPU.
      *
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index d92b8ac68dbe7..b36dc73330d46 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -252,8 +252,10 @@ export class WebGpuBackend {
       }
     };
 
-    Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
-    Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter});
+    Object.defineProperty(
+        this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false});
+    Object.defineProperty(
+        this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false});
 
     // init queryType, which is necessary for InferenceSession.create
     this.setQueryType();
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 7019758be0efd..9b27051f1b9fe 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -93,18 +93,31 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
       if (typeof navigator === 'undefined' || !navigator.gpu) {
         throw new Error('WebGPU is not supported in current environment');
       }
-      const powerPreference = env.webgpu?.powerPreference;
-      if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') {
-        throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
-      }
-      const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter;
-      if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
-        throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
-      }
-      const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
+
+      let adapter = env.webgpu.adapter as GPUAdapter | null;
       if (!adapter) {
-        throw new Error(
-            'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
+        // if adapter is not set, request a new adapter.
+        const powerPreference = env.webgpu.powerPreference;
+        if (powerPreference !== undefined && powerPreference !== 'low-power' &&
+            powerPreference !== 'high-performance') {
+          throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
+        }
+        const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
+        if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
+          throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
+        }
+        adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
+        if (!adapter) {
+          throw new Error(
+              'Failed to get GPU adapter. ' +
+              'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
+        }
+      } else {
+        // if adapter is set, validate it.
+        if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' ||
+            typeof adapter.requestDevice !== 'function') {
+          throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
+        }
       }
 
       if (!env.wasm.simd) {