Skip to content

Commit

Permalink
Removed unnecessary shapes/strides in the shader.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Jan 19, 2024
1 parent 687b266 commit dca54cf
Showing 1 changed file with 29 additions and 34 deletions.
63 changes: 29 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {castToF32, createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
epsilon: number;
Expand Down Expand Up @@ -98,36 +98,30 @@ const createSkipLayerNormProgramInfo =
{type: 'uint32', data: hiddenSize},
{type: 'float32', data: attributes.epsilon},
];
inputs.forEach((input, _) => {
programUniforms.push(...createTensorShapeVariables(input.dims));
});
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components),
inputVariable('skip', inputs[1].dataType, inputs[1].dims.length, components),
inputVariable('gamma', inputs[2].dataType, inputs[2].dims.length, components),
];
if (hasBetaInput) {
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims.length, components));
}
if (hasBiasInput) {
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims.length, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape.length, components));
programUniforms.push(...createTensorShapeVariables(outputShape));
if (hasMeanOutput) {
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim.length));
programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim));
}
if (hasInvStdDevOutput) {
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim.length));
programUniforms.push(...createTensorShapeVariables(meanInvStdDevDim));
}
if (hasInputSkipBiasSumOutput) {
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape.length, components));
programUniforms.push(...createTensorShapeVariables(outputShape));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const getShaderSource = (shaderHelper: ShaderHelper) => `
const getShaderSource = (shaderHelper: ShaderHelper) => {
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
];
if (hasBetaInput) {
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
}
if (hasBiasInput) {
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
if (hasMeanOutput) {
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
}
if (hasInvStdDevOutput) {
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
}
if (hasInputSkipBiasSumOutput) {
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `
const epsilon: f32 = ${attributes.epsilon};
${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
Expand All @@ -151,14 +145,15 @@ const createSkipLayerNormProgramInfo =
}
let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
let inv_std_dev = inverseSqrt(${
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${
hasBetaInput ? 'beta[i]' : '0.0'};
hasBetaInput ? 'beta[i]' : '0.0'};
}
}`;
};
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
if (outputCount > 1) {
outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
Expand All @@ -173,7 +168,7 @@ const createSkipLayerNormProgramInfo =
name: 'SkipLayerNormalization',
shaderCache: {
hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
inputDependencies: inputs.map((_input, _index) => 'rank')
inputDependencies: inputs.map((_input, _index) => 'type')
},
getShaderSource,
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),
Expand Down

0 comments on commit dca54cf

Please sign in to comment.