From e60587078320b06dba9338d5cc00a41a0a85068b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Dec 2024 10:24:14 -0800 Subject: [PATCH] [js/web] Update API for `ort.env.webgpu` (#23026) ### Description This PR is a replacement for #21671. It offers a new way of accessing the following: - `ort.env.webgpu.adapter`: - **deprecating**. There is no point in getting its value. Once `GPUDevice.adapterInfo` is supported, there is no point in setting the value either. - `ort.env.webgpu.device`: - set value of `GPUDevice` if the user created it. Use at user's own risk. - get value of `Promise<GPUDevice>`. If the device does not exist, a new one is created; if it exists, it is returned. - `ort.env.webgpu.powerPreference`: - **deprecating**. Users are encouraged to set `ort.env.webgpu.device` if necessary. - `ort.env.webgpu.forceFallbackAdapter`: - **deprecating**. Users are encouraged to set `ort.env.webgpu.device` if necessary. --- js/common/lib/env.ts | 33 +++++++++++++++++++++++++-------- js/web/test/test-runner.ts | 12 ++++++------ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index e70f608ad7030..d6d9f7fa48790 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -45,17 +45,19 @@ export declare namespace Env { * * This setting is available only when WebAssembly SIMD feature is available in current context. * + * @defaultValue `true` + * * @deprecated This property is deprecated. Since SIMD is supported by all major JavaScript engines, non-SIMD * build is no longer provided. This property will be removed in future release. - * @defaultValue `true` */ simd?: boolean; /** * set or get a boolean value indicating whether to enable trace. * - * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. * @defaultValue `false` + * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. 
*/ trace?: boolean; @@ -153,7 +155,7 @@ export declare namespace Env { /** * Set or get the profiling configuration. */ - profiling?: { + profiling: { /** * Set or get the profiling mode. * @@ -176,6 +178,9 @@ export declare namespace Env { * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. * * @defaultValue `undefined` + * + * @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if + * you want to use a specific power preference. */ powerPreference?: 'low-power' | 'high-performance'; /** @@ -187,6 +192,9 @@ export declare namespace Env { * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. * * @defaultValue `undefined` + * + * @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if + * you want to use a specific fallback option. */ forceFallbackAdapter?: boolean; /** @@ -199,16 +207,25 @@ export declare namespace Env { * value will be the GPU adapter that created by the underlying WebGPU backend. * * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". + * + * @deprecated It is no longer recommended to use this property. The latest WebGPU spec adds `GPUDevice.adapterInfo` + * (https://www.w3.org/TR/webgpu/#dom-gpudevice-adapterinfo), which allows to get the adapter information from the + * device. When it's available, there is no need to set/get the {@link adapter} property. */ adapter: TryGetGlobalType<'GPUAdapter'>; /** - * Get the device for WebGPU. - * - * This property is only available after the first WebGPU inference session is created. + * Set or get the GPU device for WebGPU. * - * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". + * There are 3 valid scenarios of accessing this property: + * - Set a value before the first WebGPU inference session is created. 
The value will be used by the WebGPU backend + * to perform calculations. If the value is not a `GPUDevice` object, an error will be thrown. + * - Get the value before the first WebGPU inference session is created. This will try to create a new GPUDevice + * instance. Returns a `Promise` that resolves to a `GPUDevice` object. + * - Get the value after the first WebGPU inference session is created. Returns a resolved `Promise` to the + * `GPUDevice` object used by the WebGPU backend. */ - readonly device: TryGetGlobalType<'GPUDevice'>; + get device(): Promise<TryGetGlobalType<'GPUDevice'>>; + set device(value: TryGetGlobalType<'GPUDevice'>); /** * Set or get whether validate input content. * diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index d54ba32f9f494..5de39535a5c07 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -586,11 +586,11 @@ export class TensorResultValidator { } } -function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { +async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise<ort.Tensor> { if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); } - const device = ort.env.webgpu.device as GPUDevice; + const device = await ort.env.webgpu.device; const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -612,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { }); } -function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { +async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { if (!isGpuBufferSupportedType(type)) { throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); } const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!; - const device = 
ort.env.webgpu.device as GPUDevice; + const device = await ort.env.webgpu.device; const gpuBuffer = device.createBuffer({ // eslint-disable-next-line no-bitwise usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, @@ -725,7 +725,7 @@ export async function sessionRun(options: { if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') { feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]); } else { - feeds[name] = createGpuTensorForInput(feeds[name]); + feeds[name] = await createGpuTensorForInput(feeds[name]); } } } @@ -742,7 +742,7 @@ export async function sessionRun(options: { if (options.ioBinding === 'ml-tensor') { fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims); } else { - fetches[name] = createGpuTensorForOutput(type, dims); + fetches[name] = await createGpuTensorForOutput(type, dims); } } }