Skip to content

Commit

Permalink
[JS/WebGPU] Added uniforms to Tile and Where Ops (#18768)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Added uniforms to Tile and Where Ops


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Improve performance.
  • Loading branch information
satyajandhyala authored Dec 12, 2023
1 parent b4be9e1 commit d673e39
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 39 deletions.
27 changes: 16 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/tile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

const getRepeats = (repeatsTensorView: TensorView): readonly number[] =>
Array.from(repeatsTensorView.getBigInt64Array(), Number);
Expand Down Expand Up @@ -54,30 +54,35 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
const outputSize = ShapeUtil.size(outputShape);

const dataType = inputs[0].dataType;
const input = inputVariable('input', dataType, inputShape);
const output = outputVariable('output', dataType, outputShape);
const input = inputVariable('input', dataType, inputShape.length);
const output = outputVariable('output', dataType, outputShape.length);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
${shaderHelper.declareVariables(input, output)}
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};
var inputIndices: ${input.type.indices};
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let output_indices = ${output.offsetToIndices('global_idx')};
var input_indices: ${input.type.indices};
for (var i = 0; i < ${inputShape.length}; i++) {
let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')};
let input_dim_i = ${input.indicesGet('uniforms.input_shape', 'i')};
let input_dim_value = ${output.indicesGet('output_indices', 'i')} % input_dim_i;
${input.indicesSet('inputIndices', 'i', 'inputDimValue')}
${input.indicesSet('input_indices', 'i', 'input_dim_value')}
}
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
${output.setByOffset('global_idx', input.getByIndices('input_indices'))}
}`;

return {
name: 'Tile',
shaderCache: {hint: `${repeats}`},
shaderCache: {hint: `${repeats}`, inputDependencies: ['rank']},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape)
],
}),
getShaderSource,
};
Expand Down
59 changes: 31 additions & 28 deletions js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

const createWhereOpProgramShader =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
typeOutput: number) => {
const outputSize = ShapeUtil.size(dimsOutput);
const vecSize = Math.ceil(outputSize / 4);

const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);
const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4);
const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4);
const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4);
const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4);

let assignment: string;
const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
Expand All @@ -27,20 +24,20 @@ const createWhereOpProgramShader =
expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
} else {
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
const expressionA = `aData[indexA${x}][componentA${x}]`;
const expressionB = `bData[indexB${x}][componentB${x}]`;
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let indexC${x} = offsetC${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
let componentB${x} = offsetB${x} % 4u;
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)};
let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)};
let index_a${x} = offset_a${x} / 4u;
let index_b${x} = offset_b${x} / 4u;
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
Expand All @@ -51,21 +48,21 @@ const createWhereOpProgramShader =
${singleAssignment('data', 1, 'u32')}
${singleAssignment('data', 2, 'u32')}
${singleAssignment('data', 3, 'u32')}
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
} else {
assignment = `
${singleAssignment('outputData[global_idx]', 0)}
${singleAssignment('outputData[global_idx]', 1)}
${singleAssignment('outputData[global_idx]', 2)}
${singleAssignment('outputData[global_idx]', 3)}
${singleAssignment('output_data[global_idx]', 0)}
${singleAssignment('output_data[global_idx]', 1)}
${singleAssignment('output_data[global_idx]', 2)}
${singleAssignment('output_data[global_idx]', 3)}
`;
}
}

return `
${shaderHelper.declareVariables(c, a, b, output)}
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
${assignment}
}`;
};
Expand All @@ -79,6 +76,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
let outputShape = dimsA;
let outputSize = ShapeUtil.size(dimsA);
const vecSize = Math.ceil(outputSize / 4);
// TODO: deal with zero-sized tensors (eg. dims=[1,0])

if (isBroadcast) {
Expand All @@ -92,11 +90,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>

return {
name: 'Where',
shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},
getShaderSource: (shaderHelper) =>
createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
programUniforms: [
{type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA),
...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape)
],
}),
};
};
Expand Down

0 comments on commit d673e39

Please sign in to comment.