From da6e9f562d3a8d379a9600af4ff8bcb66fd651b6 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 29 Feb 2024 23:33:10 -0800 Subject: [PATCH 01/12] Preserve zero size input tensor dims. --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 74 +++++--- js/web/test/data/ops/concat.jsonc | 196 ++++++++++++++++++++++ 2 files changed, 249 insertions(+), 21 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b142a82e551a7..59fd306239729 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -13,14 +13,14 @@ export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[]): void => { +const validateInputs = (inputs: readonly TensorView[], referenceIndex: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - const inputType = inputs[0].dataType; - const inputDimensionality = inputs[0].dims.length; - + const inputType = inputs[referenceIndex].dataType; + const inputDimensionality = inputs[referenceIndex].dims.length; + const referenceInput = inputs[referenceIndex]; for (const input of inputs) { // make sure types of all inputs match if (input.dataType !== inputType) { @@ -28,7 +28,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputDimensionality) { + if (input.dims.length !== inputDimensionality && ShapeUtil.size(input.dims) > 0 && + ShapeUtil.size(referenceInput.dims) > 0) { throw new Error('input tensors should have the same shape'); } } @@ -64,29 +65,47 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); +const computeReferenceIndex = (inputs: readonly TensorView[]): number => { + // find a none zero tensor to determine the output shape + let referenceIndex = 0; + for (let j = 0; j < inputs.length; j++) { + const size = ShapeUtil.size(inputs[j].dims); + if (size > 0) { + referenceIndex = j; + break; + } + } + return referenceIndex; +}; + +const computeOutputShape = (inputs: readonly TensorView[], axis: number, referenceIndex: number): number[] => { + const inputShape = inputs[referenceIndex].dims.slice(); if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { throw new Error('axis specified for concat doesn\'t match input dimensionality'); } - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; // ensure all of the non-concatenated axes match each other // calculate the shape of the output tensor while we do that const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { + for (let i = 0; i < inputs.length; i++) { + if (i === referenceIndex) { + continue; + } const dataNShape = inputs[i].dims.slice(); for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { // add to the placeholder for computing output shape - if (axisIndex === adjustedAxis) { - outputShape[adjustedAxis] += dataNShape[axisIndex]; + if (axisIndex === axis) { + outputShape[axis] += dataNShape[axisIndex]; } // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { + else if (inputShape[axisIndex] !== dataNShape[axisIndex] && ShapeUtil.size(dataNShape) > 0) { throw new Error('non concat dimensions must match'); } } } + return outputShape; +}; +const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, outputShape: number[]): ProgramInfo => { const outputSize = ShapeUtil.size(outputShape); const sizeInConcatAxis = new Array(inputs.length); @@ -98,7 +117,7 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P const inputRanks = []; const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; + previousSum += inputs[i].dims[axis]; sizeInConcatAxis[i] = previousSum; inputRanks.push(inputs[i].dims.length); inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); @@ -111,7 +130,7 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P programUniforms.push(...createTensorShapeVariables(outputShape)); const output = outputVariable('output', dataType, outputShape.length); - const indicesAxis = output.indicesGet('indices', adjustedAxis); + const indicesAxis = output.indicesGet('indices', axis); const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -132,12 +151,16 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); - if (inputIndex != 0u) { - let sizeInConcatAxis = array(${sizeInConcatAxisStr}); - ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; - } + if (inputIndex < ${inputs.length}u) { + if (inputIndex != 0u) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; + } - ${assignOutputData(inputVars, output)} + ${assignOutputData(inputVars, output)} + } else { + ${output.setByOffset('global_idx', '0')} + } }`; return { @@ -153,10 +176,19 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - validateInputs(context.inputs); + const referenceIndex = computeReferenceIndex(context.inputs); + validateInputs(context.inputs, referenceIndex); + const axis = attributes.axis; + const inputShape = context.inputs[referenceIndex].dims; + const adjustedAxis = (attributes.axis < 0) ? inputShape.length + axis : axis; + const outputShape = computeOutputShape(context.inputs, adjustedAxis, referenceIndex); // 0 length tensors are valid for concat, remove them const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); - context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); + if (nonEmptyInputs.length > 0) { + context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis, outputShape), {inputs: nonEmptyInputs}); + } else { + context.output(0, outputShape); + } }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/test/data/ops/concat.jsonc b/js/web/test/data/ops/concat.jsonc index d98376a72e8b5..b210ce8d82877 100644 --- a/js/web/test/data/ops/concat.jsonc +++ b/js/web/test/data/ops/concat.jsonc @@ -402,5 +402,201 @@ ] } ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors have 0 in dims along the other axis", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 0, 0, 0], + "dims": [1, 7], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 0, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors have 0 in dims along the axis", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "All input tensors have 0 in dims along the other axis", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "Zero input tensor rank is different from the other input tensors; zero dim along the axis", + "inputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + }, + { + "name": "Zero input tensor rank is different from the other input tensors", + "inputs": [ + { + "data": [], + "dims": [1, 1, 0], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0], + "dims": [2, 1], + "type": "float32" + } + ] + } + ] } ] From ae01d115547c61e9c752cc453d7f1d94a4510244 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 1 Mar 2024 09:02:54 -0800 Subject: [PATCH 02/12] Use adjustedAxis instead of axis --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 59fd306239729..a5471231c6198 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -185,7 +185,7 @@ export const concat = (context: ComputeContext, attributes: ConcatAttributes): v // 0 length tensors are valid for concat, remove them const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); if (nonEmptyInputs.length > 0) { - context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis, outputShape), {inputs: nonEmptyInputs}); + context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape), {inputs: nonEmptyInputs}); } else { context.output(0, outputShape); } From 2c0d67cd73cf93e0adffc38b70593e8086d8564e Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 1 Mar 2024 11:46:37 -0800 Subject: [PATCH 03/12] Refactor; add more tests. --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 26 +++++----- js/web/test/data/ops/concat.jsonc | 58 +++++++++++++++++++++++ 2 files changed, 69 insertions(+), 15 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index a5471231c6198..4d973119ae505 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -65,18 +65,6 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const computeReferenceIndex = (inputs: readonly TensorView[]): number => { - // find a none zero tensor to determine the output shape - let referenceIndex = 0; - for (let j = 0; j < inputs.length; j++) { - const size = ShapeUtil.size(inputs[j].dims); - if (size > 0) { - referenceIndex = j; - break; - } - } - return referenceIndex; -}; const computeOutputShape = (inputs: readonly TensorView[], axis: number, referenceIndex: number): number[] => { const inputShape = inputs[referenceIndex].dims.slice(); @@ -176,11 +164,19 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, ou }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - const referenceIndex = computeReferenceIndex(context.inputs); + // find a none zero tensor to determine the output shape + // Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of + // the inputs. + let referenceIndex = context.inputs.findIndex(input => ShapeUtil.size(input.dims) > 0); + if (referenceIndex === -1) { + referenceIndex = + context.inputs.map(input => input.dims.length) + .reduce((maxRankIndex, rank, index, array) => rank > array[maxRankIndex] ? index : maxRankIndex, 0); + } + validateInputs(context.inputs, referenceIndex); - const axis = attributes.axis; const inputShape = context.inputs[referenceIndex].dims; - const adjustedAxis = (attributes.axis < 0) ? inputShape.length + axis : axis; + const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); const outputShape = computeOutputShape(context.inputs, adjustedAxis, referenceIndex); // 0 length tensors are valid for concat, remove them const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); diff --git a/js/web/test/data/ops/concat.jsonc b/js/web/test/data/ops/concat.jsonc index b210ce8d82877..da39954dead2a 100644 --- a/js/web/test/data/ops/concat.jsonc +++ b/js/web/test/data/ops/concat.jsonc @@ -598,5 +598,63 @@ ] } ] + }, + { + "name": "Concat 2D axis=0; Preserve dims", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "All input tensors have 0 in dims along the axis", + "inputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "All input tensors have 0 in dims along the other axis", + "inputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 2, 1], + "type": "float32" + } + ] + } + ] } ] From ce6a00473b6135eb6e73da5f6880071fa5bb67fb Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 1 Mar 2024 13:17:28 -0800 Subject: [PATCH 04/12] Moved the new concat test cases from concat.jsonc to concat_zero-sized.jsonc --- js/web/test/data/ops/concat.jsonc | 254 ------------------- js/web/test/data/ops/concat_zero-sized.jsonc | 254 +++++++++++++++++++ 2 files changed, 254 insertions(+), 254 deletions(-) diff --git a/js/web/test/data/ops/concat.jsonc b/js/web/test/data/ops/concat.jsonc index da39954dead2a..d98376a72e8b5 100644 --- a/js/web/test/data/ops/concat.jsonc +++ b/js/web/test/data/ops/concat.jsonc @@ -402,259 +402,5 @@ ] } ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [ - { - "name": "axis", - "data": 1, - "type": "int" - } - ], - "cases": [ - { - "name": "Some but not all input tensors have 0 in dims along the other axis", - "inputs": [ - { - "data": [], - "dims": [0, 0], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 2], - "type": "float32" - }, - { - "data": [], - "dims": [0, 3], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1, 0, 0, 0, 0, 0, 0], - "dims": [1, 7], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [ - { - "name": "axis", - "data": 0, - "type": "int" - } - ], - "cases": [ - { - "name": "Some but not all input tensors have 0 in dims along the axis", - "inputs": [ - { - "data": [], - "dims": [0, 0], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 2], - "type": "float32" - }, - { - "data": [], - "dims": [0, 3], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [ - { - "name": "axis", - "data": 1, - "type": "int" - } - ], - "cases": [ - { - "name": "All input tensors have 0 in dims along the other axis", - "inputs": [ - { - "data": [], - "dims": [0, 0], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 2], - "type": "float32" - }, - { - "data": [], - "dims": [0, 3], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [0, 6], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [{ "name": "axis", "data": 0, "type": "int" }], - "cases": [ - { - "name": "Zero input tensor rank is different from the other input tensors; zero dim along the axis", - "inputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ] - }, - { - "name": "Zero input tensor rank is different from the other input tensors", - "inputs": [ - { - "data": [], - "dims": [1, 1, 0], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1, 0], - "dims": [2, 1], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=0; Preserve dims", - "operator": "Concat", - "attributes": [{ "name": "axis", "data": 0, "type": "int" }], - "cases": [ - { - "name": "All input tensors have 0 in dims along the axis", - "inputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [{ "name": "axis", "data": 1, "type": "int" }], - "cases": [ - { - "name": "All input tensors have 0 in dims along the other axis", - "inputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [0, 2, 1], - "type": "float32" - } - ] - } - ] } ] diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc index 7be8e8c1cc602..d393cdc6f93e1 100644 --- a/js/web/test/data/ops/concat_zero-sized.jsonc +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -557,5 +557,259 @@ ] } ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors have 0 in dims along the other axis", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 0, 0, 0, 0, 0], + "dims": [1, 7], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 0, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors have 0 in dims along the axis", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "All input tensors have 0 in dims along the other axis", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "Zero input tensor rank is different from the other input tensors; zero dim along the axis", + "inputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + }, + { + "name": "Zero input tensor rank is different from the other input tensors", + "inputs": [ + { + "data": [], + "dims": [1, 1, 0], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0], + "dims": [2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=0; Preserve dims", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "All input tensors have 0 in dims along the axis", + "inputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "All input tensors have 0 in dims along the other axis", + "inputs": [ + { + "data": [], + "dims": [0, 1, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 2, 1], + "type": "float32" + } + ] + } + ] } ] From 51c83d7332a716d27d7a5327d1b963a859c2ef33 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 4 Mar 2024 15:45:41 -0800 Subject: [PATCH 05/12] Removed calculateOutputShape function. --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 68 +++++++++-------------- 1 file changed, 25 insertions(+), 43 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 4d973119ae505..9c291aff2b0a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -13,26 +13,35 @@ export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[], referenceIndex: number): void => { +const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, axis: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - const inputType = inputs[referenceIndex].dataType; - const inputDimensionality = inputs[referenceIndex].dims.length; const referenceInput = inputs[referenceIndex]; - for (const input of inputs) { + const inputType = referenceInput.dataType; + const inputRank = referenceInput.dims.length; + const referenceInputSize = ShapeUtil.size(referenceInput.dims); + inputs.forEach((input, i) => { + if (i === referenceIndex) { + return; + } // make sure types of all inputs match if (input.dataType !== inputType) { throw new Error('input tensors should be one type'); } - - // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputDimensionality && ShapeUtil.size(input.dims) > 0 && - ShapeUtil.size(referenceInput.dims) > 0) { - throw new Error('input tensors should have the same shape'); + if (referenceInputSize > 0 && ShapeUtil.size(input.dims) > 0) { + // make sure the dimensionality of all inputs are the same + if (input.dims.length !== inputRank) { + throw new Error('input tensors should have the same shape'); + } + input.dims.forEach((dim, i) => { + if (i !== axis && dim !== referenceInput.dims[i]) { + throw new Error('non concat dimensions must match'); + } + }); } - } + }); }; const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` @@ -65,34 +74,6 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; - -const computeOutputShape = (inputs: readonly TensorView[], axis: number, referenceIndex: number): number[] => { - const inputShape = inputs[referenceIndex].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); - } - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 0; i < inputs.length; i++) { - if (i === referenceIndex) { - continue; - } - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === axis) { - outputShape[axis] += dataNShape[axisIndex]; - } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex] && ShapeUtil.size(dataNShape) > 0) { - throw new Error('non concat dimensions must match'); - } - } - } - return outputShape; -}; - const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, outputShape: number[]): ProgramInfo => { const outputSize = ShapeUtil.size(outputShape); @@ -169,15 +150,16 @@ export const concat = (context: ComputeContext, attributes: ConcatAttributes): v // the inputs. let referenceIndex = context.inputs.findIndex(input => ShapeUtil.size(input.dims) > 0); if (referenceIndex === -1) { - referenceIndex = - context.inputs.map(input => input.dims.length) - .reduce((maxRankIndex, rank, index, array) => rank > array[maxRankIndex] ? index : maxRankIndex, 0); + referenceIndex = context.inputs.reduce( + (maxRankIndex, input, index, array) => input.dims > array[maxRankIndex].dims ? index : maxRankIndex, 0); } - validateInputs(context.inputs, referenceIndex); + validateInputs(context.inputs, referenceIndex, attributes.axis); const inputShape = context.inputs[referenceIndex].dims; const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); - const outputShape = computeOutputShape(context.inputs, adjustedAxis, referenceIndex); + const outputShape = inputShape.slice(); + outputShape[adjustedAxis] = + context.inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); // 0 length tensors are valid for concat, remove them const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); if (nonEmptyInputs.length > 0) { From 2bfc23ebc164e77f49117aa1806ca2dea35ac663 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 4 Mar 2024 16:51:49 -0800 Subject: [PATCH 06/12] Removed special case handling of all inputs zero-sized --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 110 +++++++++++----------- 1 file changed, 54 insertions(+), 56 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 9c291aff2b0a2..2c01f4a26d8f7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -74,43 +74,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, outputShape: number[]): ProgramInfo => { - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - const dataType = inputs[0].dataType; - - let previousSum = 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[axis]; - sizeInConcatAxis[i] = previousSum; - inputRanks.push(inputs[i].dims.length); - inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); - inputDependencies.push('rank'); - programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); - } - for (let i = 0; i < inputs.length; ++i) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); +const createConcatProgramInfo = + (inputs: readonly TensorView[], axis: number, outputShape: number[], dataType: DataType): ProgramInfo => { + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputRanks = []; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[axis]; + sizeInConcatAxis[i] = previousSum; + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); + } + for (let i = 0; i < inputs.length; ++i) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const output = outputVariable('output', dataType, outputShape.length); - const indicesAxis = output.indicesGet('indices', axis); - const sizeInConcatAxisStr = - Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const output = outputVariable('output', dataType, outputShape.length); + const indicesAxis = output.indicesGet('indices', axis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${(() => { - shaderHelper.registerUniform('outputSize', 'u32'); - for (let i = 0; i < inputs.length; i++) { - shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); - } - return shaderHelper.declareVariables(...inputVars, output); - })()} + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} @@ -132,41 +132,39 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, ou } }`; - return { - name: 'Concat', - shaderCache: {hint: `${axis}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms, - }), - getShaderSource, - }; -}; + return { + name: 'Concat', + shaderCache: {hint: `${axis}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, + }), + getShaderSource, + }; + }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { // find a none zero tensor to determine the output shape // Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of // the inputs. - let referenceIndex = context.inputs.findIndex(input => ShapeUtil.size(input.dims) > 0); + const inputs = context.inputs; + let referenceIndex = inputs.findIndex(input => ShapeUtil.size(input.dims) > 0); if (referenceIndex === -1) { - referenceIndex = context.inputs.reduce( + referenceIndex = inputs.reduce( (maxRankIndex, input, index, array) => input.dims > array[maxRankIndex].dims ? index : maxRankIndex, 0); } - validateInputs(context.inputs, referenceIndex, attributes.axis); - const inputShape = context.inputs[referenceIndex].dims; + validateInputs(inputs, referenceIndex, attributes.axis); + const inputShape = inputs[referenceIndex].dims; const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); const outputShape = inputShape.slice(); outputShape[adjustedAxis] = - context.inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); + inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); // 0 length tensors are valid for concat, remove them - const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); - if (nonEmptyInputs.length > 0) { - context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape), {inputs: nonEmptyInputs}); - } else { - context.output(0, outputShape); - } + const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute( + createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => From e8fb4d28b01927f7238ffb9dc8d026598ceeb53f Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 4 Mar 2024 17:10:01 -0800 Subject: [PATCH 07/12] Use adjusted axis to validate. --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 2c01f4a26d8f7..1d8cb1785f17c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -155,9 +155,9 @@ export const concat = (context: ComputeContext, attributes: ConcatAttributes): v (maxRankIndex, input, index, array) => input.dims > array[maxRankIndex].dims ? index : maxRankIndex, 0); } - validateInputs(inputs, referenceIndex, attributes.axis); const inputShape = inputs[referenceIndex].dims; const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); + validateInputs(inputs, referenceIndex, adjustedAxis); const outputShape = inputShape.slice(); outputShape[adjustedAxis] = inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); From f19ece3644c0b93d62469a2aa20a91a8449c2630 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 6 Mar 2024 15:16:21 -0800 Subject: [PATCH 08/12] Rollback code allowing zero-sized input non-concat axes dims mismatch referance input --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 48 +++-- js/web/test/data/ops/concat_zero-sized.jsonc | 178 +------------------ 2 files changed, 23 insertions(+), 203 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 1d8cb1785f17c..2559eb04d9bf1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -21,7 +21,6 @@ const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, a const referenceInput = inputs[referenceIndex]; const inputType = referenceInput.dataType; const inputRank = referenceInput.dims.length; - const referenceInputSize = ShapeUtil.size(referenceInput.dims); inputs.forEach((input, i) => { if (i === referenceIndex) { return; @@ -30,17 +29,15 @@ const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, a if (input.dataType !== inputType) { throw new Error('input tensors should be one type'); } - if (referenceInputSize > 0 && ShapeUtil.size(input.dims) > 0) { - // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputRank) { - throw new Error('input tensors should have the same shape'); - } - input.dims.forEach((dim, i) => { - if (i !== axis && dim !== referenceInput.dims[i]) { - throw new Error('non concat dimensions must match'); - } - }); + // make sure the dimensionality of all inputs are the same + if (input.dims.length !== inputRank) { + throw new Error('input tensors should have the same shape'); } + input.dims.forEach((dim, i) => { + if (i !== axis && dim !== referenceInput.dims[i]) { + throw new Error('non concat dimensions must match'); + } + }); }); }; @@ -120,16 +117,12 @@ const createConcatProgramInfo = var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); - if (inputIndex < ${inputs.length}u) { - if (inputIndex != 0u) { - let sizeInConcatAxis = array(${sizeInConcatAxisStr}); - ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; - } - - ${assignOutputData(inputVars, output)} - } else { - ${output.setByOffset('global_idx', '0')} + if (inputIndex != 0u) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; } + + ${assignOutputData(inputVars, output)} }`; return { @@ -145,14 +138,15 @@ const createConcatProgramInfo = }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - // find a none zero tensor to determine the output shape - // Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of - // the inputs. + // find a none zero tensor as reference to determine the output shape + // choose input 0 as reference if all input tensors are zero-sized. const inputs = context.inputs; - let referenceIndex = inputs.findIndex(input => ShapeUtil.size(input.dims) > 0); - if (referenceIndex === -1) { - referenceIndex = inputs.reduce( - (maxRankIndex, input, index, array) => input.dims > array[maxRankIndex].dims ? index : maxRankIndex, 0); + let referenceIndex = 0; + for (let i = 0; i < inputs.length; i++) { + if (ShapeUtil.size(inputs[i].dims) > 0) { + referenceIndex = i; + break; + } } const inputShape = inputs[referenceIndex].dims; diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc index d393cdc6f93e1..be9625145d157 100644 --- a/js/web/test/data/ops/concat_zero-sized.jsonc +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -558,56 +558,6 @@ } ] }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [ - { - "name": "axis", - "data": 1, - "type": "int" - } - ], - "cases": [ - { - "name": "Some but not all input tensors have 0 in dims along the other axis", - "inputs": [ - { - "data": [], - "dims": [0, 0], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 2], - "type": "float32" - }, - { - "data": [], - "dims": [0, 3], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1, 0, 0, 0, 0, 0, 0], - "dims": [1, 7], - "type": "float32" - } - ] - } - ] - }, { "name": "Concat 2D axis=1; Preserve dims", "operator": "Concat", @@ -620,28 +570,13 @@ ], "cases": [ { - "name": "Some but not all input tensors have 0 in dims along the axis", + "name": "Some but not all input tensors are zero-sized", "inputs": [ - { - "data": [], - "dims": [0, 0], - "type": "float32" - }, { "data": [], "dims": [0, 1], "type": "float32" }, - { - "data": [], - "dims": [0, 2], - "type": "float32" - }, - { - "data": [], - "dims": [0, 3], - "type": "float32" - }, { "data": [1], "dims": [1, 1], @@ -670,7 +605,7 @@ ], "cases": [ { - "name": "All input tensors have 0 in dims along the other axis", + "name": "All input tensors are zero-sized", "inputs": [ { "data": [], @@ -702,114 +637,5 @@ ] } ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [{ "name": "axis", "data": 0, "type": "int" }], - "cases": [ - { - "name": "Zero input tensor rank is different from the other input tensors; zero dim along the axis", - "inputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ] - }, - { - "name": "Zero input tensor rank is different from the other input tensors", - "inputs": [ - { - "data": [], - "dims": [1, 1, 0], - "type": "float32" - }, - { - "data": [1], - "dims": [1, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [1, 0], - "dims": [2, 1], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=0; Preserve dims", - "operator": "Concat", - "attributes": [{ "name": "axis", "data": 0, "type": "int" }], - "cases": [ - { - "name": "All input tensors have 0 in dims along the axis", - "inputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - } - ] - } - ] - }, - { - "name": "Concat 2D axis=1; Preserve dims", - "operator": "Concat", - "attributes": [{ "name": "axis", "data": 1, "type": "int" }], - "cases": [ - { - "name": "All input tensors have 0 in dims along the other axis", - "inputs": [ - { - "data": [], - "dims": [0, 1, 1], - "type": "float32" - }, - { - "data": [], - "dims": [0, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [0, 2, 1], - "type": "float32" - } - ] - } - ] } ] From c0fab593e07979dbdbdc9c10e08005f838b2b9f9 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 6 Mar 2024 15:31:59 -0800 Subject: [PATCH 09/12] Keep the variable name adjustedAxis --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 2559eb04d9bf1..9e6b6159759ac 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -72,7 +72,7 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe }; const createConcatProgramInfo = - (inputs: readonly TensorView[], axis: number, outputShape: number[], dataType: DataType): ProgramInfo => { + (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { const outputSize = ShapeUtil.size(outputShape); const sizeInConcatAxis = new Array(inputs.length); @@ -83,7 +83,7 @@ const createConcatProgramInfo = const inputRanks = []; const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[axis]; + previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; inputRanks.push(inputs[i].dims.length); inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); @@ -96,7 +96,7 @@ const createConcatProgramInfo = programUniforms.push(...createTensorShapeVariables(outputShape)); const output = outputVariable('output', dataType, outputShape.length); - const indicesAxis = output.indicesGet('indices', axis); + const indicesAxis = output.indicesGet('indices', adjustedAxis); const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -127,7 +127,7 @@ const createConcatProgramInfo = return { name: 'Concat', - shaderCache: {hint: `${axis}`, inputDependencies}, + shaderCache: {hint: `${adjustedAxis}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, From ed048b095bfc4f5b2bc6714cd0edcf175f7ff60e Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 6 Mar 2024 17:49:00 -0800 Subject: [PATCH 10/12] Removed referenceInput. --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 9e6b6159759ac..fc1b63683ad1a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -13,11 +13,11 @@ export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, axis: number): void => { +const validateInputs = (inputs: readonly TensorView[], axis: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - + const referenceIndex = 0; const referenceInput = inputs[referenceIndex]; const inputType = referenceInput.dataType; const inputRank = referenceInput.dims.length; @@ -141,17 +141,10 @@ export const concat = (context: ComputeContext, attributes: ConcatAttributes): v // find a none zero tensor as reference to determine the output shape // choose input 0 as reference if all input tensors are zero-sized. const inputs = context.inputs; - let referenceIndex = 0; - for (let i = 0; i < inputs.length; i++) { - if (ShapeUtil.size(inputs[i].dims) > 0) { - referenceIndex = i; - break; - } - } - const inputShape = inputs[referenceIndex].dims; + const inputShape = inputs[0].dims; const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); - validateInputs(inputs, referenceIndex, adjustedAxis); + validateInputs(inputs, adjustedAxis); const outputShape = inputShape.slice(); outputShape[adjustedAxis] = inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); From 971b6121a97bcc7b6248ab8245813f8edbe26296 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 6 Mar 2024 18:02:15 -0800 Subject: [PATCH 11/12] Fixed comment --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 3 --- 1 file changed, 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index fc1b63683ad1a..97ead2e6624a4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -138,10 +138,7 @@ const createConcatProgramInfo = }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - // find a none zero tensor as reference to determine the output shape - // choose input 0 as reference if all input tensors are zero-sized. const inputs = context.inputs; - const inputShape = inputs[0].dims; const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); validateInputs(inputs, adjustedAxis); From adfca56b164a2290627c81ced6385d0ffd142a81 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 7 Mar 2024 16:40:12 -0800 Subject: [PATCH 12/12] Check if the axis is with in the range. --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 97ead2e6624a4..010ee589c44fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -140,7 +140,7 @@ const createConcatProgramInfo = export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { const inputs = context.inputs; const inputShape = inputs[0].dims; - const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0); + const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); validateInputs(inputs, adjustedAxis); const outputShape = inputShape.slice(); outputShape[adjustedAxis] =