diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 2c8a1966cd0b3..f21111cccfee7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {castToF32, createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -98,36 +98,30 @@ const createSkipLayerNormProgramInfo = {type: 'uint32', data: hiddenSize}, {type: 'float32', data: attributes.epsilon}, ]; - inputs.forEach((input, _) => { - programUniforms.push(...createTensorShapeVariables(input.dims)); - }); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims.length, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims.length, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims.length, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims.length, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape.length, components)); - programUniforms.push(...createTensorShapeVariables(outputShape)); - if (hasMeanOutput) { - variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim.length)); - programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim.length)); - programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape.length, components)); - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` const epsilon: f32 = ${attributes.epsilon}; ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} @@ -151,14 +145,15 @@ const createSkipLayerNormProgramInfo = } let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); let inv_std_dev = inverseSqrt(${ - sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); + sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} for (var i: u32 = 0; i < hidden_size_vectorized; i++) { output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${ - hasBetaInput ? 'beta[i]' : '0.0'}; + hasBetaInput ? 'beta[i]' : '0.0'}; } }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -173,7 +168,7 @@ const createSkipLayerNormProgramInfo = name: 'SkipLayerNormalization', shaderCache: { hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, - inputDependencies: inputs.map((_input, _index) => 'rank') + inputDependencies: inputs.map((_input, _index) => 'type') }, getShaderSource, getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),