Skip to content

Commit

Permalink
fix webgpu split (#17258)
Browse files Browse the repository at this point in the history
fix webgpu split for the case of split_sizes coming from input[1]
  • Loading branch information
guschmue authored and centwang committed Aug 28, 2023
1 parent 83fbd55 commit 6462ee6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
const createSplitAttributesFromInputs =
(inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
const splitSizes: number[] = [];
let numOutputs: number = attributes.numOutputs;
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
numOutputs = splitSizes.length;
}
return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
};

const calculateOutputIndexImpl = (numberOfTensors: number): string => `
Expand Down Expand Up @@ -114,7 +116,7 @@ const createSplitProgramInfoLoader =
const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
const metadata:
ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
};

export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/js/operators/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase {
if (num_outputs_ < 0) {
num_outputs_ = split_sizes.size();
}
} else if (split_sizes_.size() == 0) {
// Compute split_sizes from input shape and num_outputs
} else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) {
// Compute split_sizes from input shape and num_outputs.
// TODO: Shape might not be known at this point, better to handle this in javascript
auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast<int32_t>(axis_)).dim_value();
int64_t split_size_sum = 0;
if (num_outputs_ < 0) {
Expand All @@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase {
ORT_ENFORCE(split_size_sum == total_split_size,
"Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")");
}
// else: let javascript handle all other cases, ie. split_sizes come as input[1]

JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
"numOutputs" : $2,
Expand Down

0 comments on commit 6462ee6

Please sign in to comment.