Skip to content

Commit

Permalink
fix webgpu split (microsoft#17258)
Browse files Browse the repository at this point in the history
fix webgpu split for the case of split_sizes coming from input[1]
  • Loading branch information
guschmue authored Aug 22, 2023
1 parent b10674c commit b5f69d0
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
const createSplitAttributesFromInputs =
(inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
const splitSizes: number[] = [];
let numOutputs: number = attributes.numOutputs;
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
numOutputs = splitSizes.length;
}
return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
};

const calculateOutputIndexImpl = (numberOfTensors: number): string => `
Expand Down Expand Up @@ -114,7 +116,7 @@ const createSplitProgramInfoLoader =
const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
const metadata:
ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
};

export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
Expand Down

0 comments on commit b5f69d0

Please sign in to comment.