From d673e39ad89a709d5896510bcd496927567b4b79 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Mon, 11 Dec 2023 20:58:52 -0800 Subject: [PATCH] [JS/WebGPU] Added uniforms to Tile and Where Ops (#18768) ### Description Added uniforms to Tile and Where Ops ### Motivation and Context Improve performance. --- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 27 ++++++----- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 59 +++++++++++++----------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index e294541a775ca..90a36a7bec2a9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const getRepeats = (repeatsTensorView: TensorView): readonly number[] => Array.from(repeatsTensorView.getBigInt64Array(), Number); @@ -54,30 +54,35 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf const outputSize = ShapeUtil.size(outputShape); const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, inputShape); - const output = outputVariable('output', dataType, outputShape); + const input = inputVariable('input', dataType, inputShape.length); + const output = outputVariable('output', dataType, outputShape.length); const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.declareVariables(input, output)} + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let outputIndices = 
${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let output_indices = ${output.offsetToIndices('global_idx')}; + var input_indices: ${input.type.indices}; for (var i = 0; i < ${inputShape.length}; i++) { - let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')}; + let input_dim_i = ${input.indicesGet('uniforms.input_shape', 'i')}; + let input_dim_value = ${output.indicesGet('output_indices', 'i')} % input_dim_i; - ${input.indicesSet('inputIndices', 'i', 'inputDimValue')} + ${input.indicesSet('input_indices', 'i', 'input_dim_value')} } - ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} + ${output.setByOffset('global_idx', input.getByIndices('input_indices'))} }`; return { name: 'Tile', - shaderCache: {hint: `${repeats}`}, + shaderCache: {hint: `${repeats}`, inputDependencies: ['rank']}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ + {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape) + ], }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 6f66dd86b4088..687ee054096cc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const createWhereOpProgramShader = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: 
readonly number[], isBroadcast: boolean, typeOutput: number) => { - const outputSize = ShapeUtil.size(dimsOutput); - const vecSize = Math.ceil(outputSize / 4); - - const output = outputVariable('outputData', typeOutput, dimsOutput, 4); - const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); - const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); - const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4); + const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4); + const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4); + const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4); let assignment: string; const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; @@ -27,20 +24,20 @@ const createWhereOpProgramShader = expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); } else { const singleAssignment = (resStr: string, x: number, typeCast = '') => { - const expressionA = `aData[indexA${x}][componentA${x}]`; - const expressionB = `bData[indexB${x}][componentB${x}]`; + const expressionA = `a_data[index_a${x}][component_a${x}]`; + const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; return ` - let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; - let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; - let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; - let indexA${x} = offsetA${x} / 4u; - let indexB${x} = offsetB${x} / 4u; - 
let indexC${x} = offsetC${x} / 4u; - let componentA${x} = offsetA${x} % 4u; - let componentB${x} = offsetB${x} % 4u; + let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; + let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)}; + let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)}; + let index_a${x} = offset_a${x} / 4u; + let index_b${x} = offset_b${x} / 4u; + let index_c${x} = offset_c${x} / 4u; + let component_a${x} = offset_a${x} % 4u; + let component_b${x} = offset_b${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; @@ -51,21 +48,21 @@ const createWhereOpProgramShader = ${singleAssignment('data', 1, 'u32')} ${singleAssignment('data', 2, 'u32')} ${singleAssignment('data', 3, 'u32')} - outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; } else { assignment = ` - ${singleAssignment('outputData[global_idx]', 0)} - ${singleAssignment('outputData[global_idx]', 1)} - ${singleAssignment('outputData[global_idx]', 2)} - ${singleAssignment('outputData[global_idx]', 3)} + ${singleAssignment('output_data[global_idx]', 0)} + ${singleAssignment('output_data[global_idx]', 1)} + ${singleAssignment('output_data[global_idx]', 2)} + ${singleAssignment('output_data[global_idx]', 3)} `; } } return ` - ${shaderHelper.declareVariables(c, a, b, output)} + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} ${assignment} }`; }; @@ -79,6 +76,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const isBroadcast = 
!(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); let outputShape = dimsA; let outputSize = ShapeUtil.size(dimsA); + // vec_size is derived inside getRunData because outputSize is a `let` that the isBroadcast branch below may still reassign. // TODO: deal with zero-sized tensors (eg. dims=[1,0]) if (isBroadcast) { @@ -92,11 +90,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => return { name: 'Where', + shaderCache: {inputDependencies: ['rank', 'rank', 'rank']}, getShaderSource: (shaderHelper) => createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, + programUniforms: [ + {type: 'uint32', data: Math.ceil(outputSize / 4)}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), + ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) + ], }), }; };