diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 182c1cd351c9d..d92b8ac68dbe7 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; +import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; interface CommandInfo { readonly kernelId: number; @@ -94,11 +94,32 @@ const getProgramInfoUniqueKey = return key; }; +class AdapterInfoImpl implements AdapterInfo { + readonly architecture?: string; + readonly vendor?: string; + + constructor(adapterInfo: GPUAdapterInfo) { + if (adapterInfo) { + this.architecture = adapterInfo.architecture; + this.vendor = adapterInfo.vendor; + } + } + + isArchitecture(architecture: GpuArchitecture): boolean { + return this.architecture === architecture; + } + + isVendor(vendor: GpuVendor): boolean { + return this.vendor === vendor; + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. */ export class WebGpuBackend { + adapterInfo: AdapterInfoImpl; device: GPUDevice; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping @@ -212,6 +233,7 @@ export class WebGpuBackend { } this.device = await adapter.requestDevice(deviceDescriptor); + this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo()); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); this.kernels = new Map(); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b64abf9cc5424..4936b94ef7a86 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; -import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; +import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; /* eslint-disable no-bitwise */ @@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView { } class ComputeContextImpl implements ComputeContext { + readonly adapterInfo: AdapterInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext { private customDataOffset = 0; private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + this.adapterInfo = backend.adapterInfo; const heapU32 = module.HEAPU32; // extract context data diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 5afec0389fac8..b68d4dcae4cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - // Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases: + // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other + // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs. // [webgpu]Conv - conv - vectorize group - B // [webgpu]Conv - conv - vectorize group - D - const disableGroupedConvVectorize = true; - if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && + const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere'); + if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { const outputShape = calculateOutputShape( inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index ba5b84fcfe067..48e0855f01a97 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -15,6 +15,13 @@ export enum GpuDataType { } export type GpuDataId = number; +export type GpuArchitecture = 'ampere'; +export type GpuVendor = 'amd'|'intel'|'nvidia'; +export interface AdapterInfo { + isArchitecture: (architecture: GpuArchitecture) => boolean; + isVendor: (vendor: GpuVendor) => boolean; +} + export interface GpuData { type: GpuDataType; id: GpuDataId; @@ -146,6 +153,11 @@ export interface ComputeContextInputsOutputsMapping { * A ComputeContext instance carries the states that representing the current running of a kernel. */ export interface ComputeContext { + /** + * gpu adapter info + */ + readonly adapterInfo: AdapterInfo; + /** * stores the pointer to OpKernelContext */