From 6462ee67eac4074cdac512e63c0319bb7ed32ebd Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 22 Aug 2023 16:49:22 -0700 Subject: [PATCH] fix webgpu split (#17258) fix webgpu split for the case of split_sizes coming from input[1] --- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 6 ++++-- onnxruntime/core/providers/js/operators/split.h | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 54f493422816f..f5b8a7e3b0ef9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createSplitAttributesFromInputs = (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { const splitSizes: number[] = []; + let numOutputs: number = attributes.numOutputs; if (inputs[1].dims[0] > 0) { inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); + numOutputs = splitSizes.length; } - return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); + return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes}); }; const calculateOutputIndexImpl = (numberOfTensors: number): string => ` @@ -114,7 +116,7 @@ const createSplitProgramInfoLoader = const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes); const metadata: ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; - return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)}; + return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)}; }; export const split = (context: ComputeContext, attributes: SplitAttributes): void => { diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 691af48711a56..cfacc1aa6a363 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase { if (num_outputs_ < 0) { num_outputs_ = split_sizes.size(); } - } else if (split_sizes_.size() == 0) { - // Compute split_sizes from input shape and num_outputs + } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { + // Compute split_sizes from input shape and num_outputs. + // TODO: Shape might not be known at this point, better to handle this in javascript auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast(axis_)).dim_value(); int64_t split_size_sum = 0; if (num_outputs_ < 0) { @@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase { ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")"); } + // else: let javascript handle all other cases, ie. split_sizes come as input[1] JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2,