From 6edc6ea87ad378468d26396854cf835a0f49e823 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 30 Oct 2024 08:10:14 +0800 Subject: [PATCH] [js/webgpu] Optimize InstanceNorm in some shapes (#22637) BUG #22031 Optimize below two situations: 1. Increase workgroupSize if only one workgroup is dispatched. 2. Avoid transpose if not necessary. The overall time of demucs model becomes 106.36 ms from 154.60 ms on my dGPUs with this PR and PR #22577 --- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 859bd850862aa..a357d29667319 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -36,7 +36,10 @@ const computeChannelScaleShift = ( const f32Type = components === 1 ? 'f32' : `vec${components}f`; const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`; const unitsOfWork = n * c; - + let workgroupSize = 64; + if (unitsOfWork === 1) { + workgroupSize = 256; + } const inputShape = [n, c, h / components]; const outputShape = [n, c, 2]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; @@ -49,7 +52,6 @@ const computeChannelScaleShift = ( const b = inputVariable('bias', bias.dataType, bias.dims); const output = outputVariable('output', DataType.float, 3, 2); const variables = [x, s, b, output]; - const workgroupSize = 64; return ` var workgroup_shared : array<${wgType}, ${workgroupSize}>; const workgroup_size = ${workgroupSize}u; @@ -91,7 +93,7 @@ const computeChannelScaleShift = ( { name: 'InstanceNormComputeChannelScaleShift', // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + shaderCache: { hint: `${components};${epsilon};${workgroupSize}`, inputDependencies }, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: DataType.float }], dispatchGroup: { x: unitsOfWork }, @@ -187,14 +189,21 @@ const createInstanceNormNHWCProgramInfo = ( const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // 1. transpose x from NHWC to NCHW + let needTranspose = false; const transposedXPerm = [0, xShape.length - 1]; for (let i = 0; i < xShape.length - 2; i++) { + needTranspose = needTranspose || xShape[i + 1] !== 1; transposedXPerm.push(i + 1); } - const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { - inputs: [context.inputs[0]], - outputs: [-1], - })[0]; + + needTranspose = needTranspose && xShape[xShape.length - 1] !== 1; + + const transposedX = needTranspose + ? context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { + inputs: [context.inputs[0]], + outputs: [-1], + })[0] + : context.inputs[0].reshape(Array.from({ length: xShape.length }, (_, i) => xShape[transposedXPerm[i]])); // 2. compute channel scale and channel shift. const channelScaleShift = computeChannelScaleShift( context,