Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Preserve zero size input tensor dims. #19737

Merged
merged 12 commits into from
Mar 8, 2024
Prev Previous commit
Next Next commit
Refactor; add more tests.
  • Loading branch information
satyajandhyala committed Mar 1, 2024
commit 2c0d67cd73cf93e0adffc38b70593e8086d8564e
26 changes: 11 additions & 15 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
@@ -65,18 +65,6 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

const computeReferenceIndex = (inputs: readonly TensorView[]): number => {
// find a none zero tensor to determine the output shape
let referenceIndex = 0;
for (let j = 0; j < inputs.length; j++) {
const size = ShapeUtil.size(inputs[j].dims);
if (size > 0) {
referenceIndex = j;
break;
}
}
return referenceIndex;
};

const computeOutputShape = (inputs: readonly TensorView[], axis: number, referenceIndex: number): number[] => {
const inputShape = inputs[referenceIndex].dims.slice();
@@ -176,11 +164,19 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number, ou
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
const referenceIndex = computeReferenceIndex(context.inputs);
// find a none zero tensor to determine the output shape
// Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of
// the inputs.
let referenceIndex = context.inputs.findIndex(input => ShapeUtil.size(input.dims) > 0);
if (referenceIndex === -1) {
referenceIndex =
context.inputs.map(input => input.dims.length)
.reduce((maxRankIndex, rank, index, array) => rank > array[maxRankIndex] ? index : maxRankIndex, 0);
}

validateInputs(context.inputs, referenceIndex);
const axis = attributes.axis;
const inputShape = context.inputs[referenceIndex].dims;
const adjustedAxis = (attributes.axis < 0) ? inputShape.length + axis : axis;
const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0);
const outputShape = computeOutputShape(context.inputs, adjustedAxis, referenceIndex);
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
58 changes: 58 additions & 0 deletions js/web/test/data/ops/concat.jsonc
Original file line number Diff line number Diff line change
@@ -598,5 +598,63 @@
]
}
]
},
{
"name": "Concat 2D axis=0; Preserve dims",
"operator": "Concat",
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
"cases": [
{
"name": "All input tensors have 0 in dims along the axis",
"inputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "All input tensors have 0 in dims along the other axis",
"inputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 2, 1],
"type": "float32"
}
]
}
]
}
]