From 6adfa140b3dfa1c8006097416d778dd58c837e87 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 22 Aug 2023 10:53:17 -0700 Subject: [PATCH 1/4] fix webgpu split --- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 6 ++++-- onnxruntime/core/providers/js/operators/split.h | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 54f493422816f..c7863ea91e52f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -10,8 +10,8 @@ import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './comm export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; - readonly numOutputs: number; - readonly splitSizes: number[]; + numOutputs: number; + splitSizes: number[]; } const validateInputs = (inputs: readonly TensorView[]): void => { @@ -25,6 +25,8 @@ const createSplitAttributesFromInputs = const splitSizes: number[] = []; if (inputs[1].dims[0] > 0) { inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); + attributes.splitSizes = splitSizes; + attributes.numOutputs = attributes.splitSizes.length; } return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); }; diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 691af48711a56..14a35371d8d63 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -25,8 +25,10 @@ class Split : public JsKernel, public SplitBase { if (num_outputs_ < 0) { num_outputs_ = split_sizes.size(); } - } else if (split_sizes_.size() == 0) { - // Compute split_sizes from input shape and num_outputs + } + else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { + // Compute split_sizes from input shape and num_outputs. + // TODO: Shape might not be known at this point, better to handle this in javascript auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast(axis_)).dim_value(); int64_t split_size_sum = 0; if (num_outputs_ < 0) { @@ -44,6 +46,7 @@ class Split : public JsKernel, public SplitBase { ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")"); } + // else: let javascript handle all other cases, ie. split_sizes come as input[1] JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, From b2b92fe8a9e91477eba7d03e6b47ddb28758ca01 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 22 Aug 2023 11:55:19 -0700 Subject: [PATCH 2/4] fix lint --- onnxruntime/core/providers/js/operators/split.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 14a35371d8d63..cfacc1aa6a363 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -25,8 +25,7 @@ class Split : public JsKernel, public SplitBase { if (num_outputs_ < 0) { num_outputs_ = split_sizes.size(); } - } - else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { + } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { // Compute split_sizes from input shape and num_outputs. // TODO: Shape might not be known at this point, better to handle this in javascript auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast(axis_)).dim_value(); From 1f2ee0baa3fc632b3338c55179dd0090bcea0691 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 22 Aug 2023 13:30:45 -0700 Subject: [PATCH 3/4] avoid updateing attributes --- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index c7863ea91e52f..8e80851d987f2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -10,8 +10,8 @@ import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './comm export interface SplitAttributes extends AttributeWithCacheKey { readonly axis: number; - numOutputs: number; - splitSizes: number[]; + readonly numOutputs: number; + readonly splitSizes: number[]; } const validateInputs = (inputs: readonly TensorView[]): void => { @@ -23,12 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createSplitAttributesFromInputs = (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { const splitSizes: number[] = []; + let numOutputs: number = attributes.numOutputs; if (inputs[1].dims[0] > 0) { inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); - attributes.splitSizes = splitSizes; - attributes.numOutputs = attributes.splitSizes.length; + numOutputs = splitSizes.length; } - return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); + return createAttributeWithCacheKey({numOutputs: numOutputs, axis: attributes.axis, splitSizes}); }; const calculateOutputIndexImpl = (numberOfTensors: number): string => ` @@ -116,7 +116,7 @@ const createSplitProgramInfoLoader = const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes); const metadata: ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; - return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)}; + return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)}; }; export const split = (context: ComputeContext, attributes: SplitAttributes): void => { From 27fe5b115441d5d54c5c8b110f1d4dd1c0ae98ea Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 22 Aug 2023 14:02:08 -0700 Subject: [PATCH 4/4] lint --- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 8e80851d987f2..f5b8a7e3b0ef9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -28,7 +28,7 @@ const createSplitAttributesFromInputs = inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); numOutputs = splitSizes.length; } - return createAttributeWithCacheKey({numOutputs: numOutputs, axis: attributes.axis, splitSizes}); + return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes}); }; const calculateOutputIndexImpl = (numberOfTensors: number): string => `