Skip to content

Commit

Permalink
Removed special case handling of all inputs zero-sized
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Mar 5, 2024
1 parent 51c83d7 commit 2bfc23e
Showing 1 changed file with 54 additions and 56 deletions.
110 changes: 54 additions & 56 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,43 +74,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, outputShape: number[]): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);
const dataType = inputs[0].dataType;

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[axis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
const createConcatProgramInfo =
(inputs: readonly TensorView[], axis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[axis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));

const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', axis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', axis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
${(() => {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
Expand All @@ -132,41 +132,39 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, ou
}
}`;

return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};
return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
// find a none zero tensor to determine the output shape
// Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of
// the inputs.
let referenceIndex = context.inputs.findIndex(input => ShapeUtil.size(input.dims) > 0);
const inputs = context.inputs;
let referenceIndex = inputs.findIndex(input => ShapeUtil.size(input.dims) > 0);
if (referenceIndex === -1) {
referenceIndex = context.inputs.reduce(
referenceIndex = inputs.reduce(
(maxRankIndex, input, index, array) => input.dims > array[maxRankIndex].dims ? index : maxRankIndex, 0);
}

validateInputs(context.inputs, referenceIndex, attributes.axis);
const inputShape = context.inputs[referenceIndex].dims;
validateInputs(inputs, referenceIndex, attributes.axis);
const inputShape = inputs[referenceIndex].dims;
const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0);
const outputShape = inputShape.slice();
outputShape[adjustedAxis] =
context.inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
if (nonEmptyInputs.length > 0) {
context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape), {inputs: nonEmptyInputs});
} else {
context.output(0, outputShape);
}
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
};

export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
Expand Down

0 comments on commit 2bfc23e

Please sign in to comment.