From e004c5719a090d2c0a9cfc308b718b6bd92d4b21 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 7 Aug 2023 11:07:15 -0700 Subject: [PATCH 1/7] Changes --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 116 ++++++++++++++++++ .../providers/js/js_execution_provider.cc | 4 + .../core/providers/js/operators/tile.cc | 30 +++++ .../core/providers/js/operators/tile.h | 14 +++ 5 files changed, 166 insertions(+) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/tile.ts create mode 100644 onnxruntime/core/providers/js/operators/tile.cc create mode 100644 onnxruntime/core/providers/js/operators/tile.h diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 23b47033e548a..5a01a3128659d 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -17,6 +17,7 @@ import {parseResizeAttributes, resize} from './ops/resize'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; +import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; import {ComputeContext} from './types'; @@ -88,5 +89,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Tan', [unaryOps.tan]], ['Tanh', [unaryOps.tanh]], ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], + ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts new file mode 100644 index 0000000000000..c1ef143318503 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +export const tileProgramMetadata = { + name: 'Tile', + inputTypes: [GpuDataType.default] +}; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Tile requires 2 inputs.'); + } + + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 && + inputs[0].dataType !== DataType.uint32) { + throw new Error('Tile only support float, int32, and uint32 data types'); + } + + if (inputs[1].dataType !== DataType.int64) { + throw new Error('Tile `repeats` input should be of int64 data type'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('Tile `repeats` input should be 1-D'); + } + + const repeats: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach(v => repeats.push(Number(v))); + } + + if (repeats.length !== inputs[0].dims.length) { + throw new Error('Tile `repeats` input should have same number of elements as rank of input data tensor'); + } +}; + +const getOutputShape = (inputShape: readonly number[], repeats: readonly number[]): readonly number[] => { + const outputShape: number[] = []; + + for (let i = 0; i < inputShape.length; ++i) { + outputShape.push(inputShape[i] * repeats[i]); + } + + return outputShape; +}; + +export const createTileProgramInfo = + (tileProgramMetadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + // We currently only support 4-byte element tensors, so using f32 here is safe + // TODO: support other data types for Tile + const dataType = 'f32'; + const inputShape = inputs[0].dims; + + const repeats: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach(v => repeats.push(Number(v))); + } + + const outputShape = getOutputShape(inputShape, repeats); + const outputSize = ShapeUtil.size(outputShape); + + const inputIndicesHelper = createIndicesHelper('input', inputShape); + const outputIndicesHelper = createIndicesHelper('output', outputShape); + + const isl = inputShape.length; + const calculateInputIndexImpl = (): string => ` + fn calculateInputIndex(inputIndices: ${inputIndicesHelper.iType}, outputIndices: ${ + outputIndicesHelper.iType}) -> void { + for (var i = 0; i < ${isl}; i++) { + // TODO: IndicesHelper should offer uniform way to get/set indices for all ranks + inputIndices${isl >= 2 ? '[i]' : ''} = (outputIndices${isl >= 2 ? '[i]' : ''} % inputShape[i]); + } + }`; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + + const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + ${calculateInputIndexImpl()}; + @group(0) @binding(0) var input : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; + + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + calculateInputIndexImpl(inputIndices, outputIndices); + output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}]; + }`; + + return { + ...tileProgramMetadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const tile = (context: ComputeContext): void => { + validateInputs(context.inputs); + const cacheHint = context.inputs.map(x => x.dims.toString()).join('_'); + context.compute( + {...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)}, + {inputs: [0]}); +}; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 677a2543014ce..3831f37d69104 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -285,6 +285,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile); std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -506,6 +508,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; diff --git a/onnxruntime/core/providers/js/operators/tile.cc b/onnxruntime/core/providers/js/operators/tile.cc new file mode 100644 index 0000000000000..e1b693c043885 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/tile.cc @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "expand.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Tile, + kOnnxDomain, + 6, + 12, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +ONNX_OPERATOR_KERNEL_EX( + Tile, + kOnnxDomain, + 13, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Tile); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/tile.h b/onnxruntime/core/providers/js/operators/tile.h new file mode 100644 index 0000000000000..b3eac0ef828cd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/tile.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Tile, Tile); + +} // namespace js +} // namespace onnxruntime From 7c7ef6bc1036d7e541a180643384c838cf1ad285 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 10 Aug 2023 14:24:25 -0700 Subject: [PATCH 2/7] Modify header --- onnxruntime/core/providers/js/operators/tile.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/js/operators/tile.cc b/onnxruntime/core/providers/js/operators/tile.cc index e1b693c043885..6a05e5c9f7f41 100644 --- a/onnxruntime/core/providers/js/operators/tile.cc +++ b/onnxruntime/core/providers/js/operators/tile.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/js/js_kernel.h" -#include "expand.h" +#include "tile.h" namespace onnxruntime { namespace js { From ca92f03cc4513a06310638f4e8257e619b30d587 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 11 Aug 2023 11:46:44 -0700 Subject: [PATCH 3/7] Tile support --- js/web/docs/webgpu-operators.md | 1 + js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 61 ++++---- js/web/test/data/ops/tile.jsonc | 147 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 5 +- .../core/providers/js/operators/tile.cc | 8 +- 5 files changed, 192 insertions(+), 30 deletions(-) create mode 100644 js/web/test/data/ops/tile.jsonc diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 84bf69b51fe0b..502d2987195bc 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -76,5 +76,6 @@ Do not modify directly.* | Sub | ai.onnx(7-12,13,14+) | | | Tan | ai.onnx(7+) | | | ThresholdedRelu | ai.onnx(10+) | | +| Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index c1ef143318503..9f0fd8bb07a7f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -13,6 +13,14 @@ export const tileProgramMetadata = { inputTypes: [GpuDataType.default] }; +const getRepeats = (repeatsTensorView: TensorView): readonly number[] => { + const repeats: number[] = []; + if (repeatsTensorView.dims[0] > 0) { + repeatsTensorView.getBigInt64Array().forEach(v => repeats.push(Number(v))); + } + return repeats; +}; + const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { throw new Error('Tile requires 2 inputs.'); @@ -31,10 +39,7 @@ const validateInputs = (inputs: readonly TensorView[]): void => { throw new Error('Tile `repeats` input should be 1-D'); } - const repeats: number[] = []; - if (inputs[1].dims[0] > 0) { - inputs[1].getBigInt64Array().forEach(v => repeats.push(Number(v))); - } + const repeats: readonly number[] = getRepeats(inputs[1]); if (repeats.length !== inputs[0].dims.length) { throw new Error('Tile `repeats` input should have same number of elements as rank of input data tensor'); @@ -58,10 +63,7 @@ export const createTileProgramInfo = const dataType = 'f32'; const inputShape = inputs[0].dims; - const repeats: number[] = []; - if (inputs[1].dims[0] > 0) { - inputs[1].getBigInt64Array().forEach(v => repeats.push(Number(v))); - } + const repeats: readonly number[] = getRepeats(inputs[1]); const outputShape = getOutputShape(inputShape, repeats); const outputSize = ShapeUtil.size(outputShape); @@ -71,33 +73,36 @@ export const createTileProgramInfo = const isl = inputShape.length; const calculateInputIndexImpl = (): string => ` - fn calculateInputIndex(inputIndices: ${inputIndicesHelper.iType}, outputIndices: ${ - outputIndicesHelper.iType}) -> void { + fn calculateInputIndex(outputIndices: ${outputIndicesHelper.iType}) -> ${inputIndicesHelper.iType} { + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + for (var i = 0; i < ${isl}; i++) { // TODO: IndicesHelper should offer uniform way to get/set indices for all ranks inputIndices${isl >= 2 ? '[i]' : ''} = (outputIndices${isl >= 2 ? '[i]' : ''} % inputShape[i]); } - }`; + + return inputIndices; + }`; const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - ${calculateInputIndexImpl()}; - @group(0) @binding(0) var input : array<${dataType}>; - @group(0) @binding(1) var output : array<${dataType}>; + const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + ${calculateInputIndexImpl()}; + @group(0) @binding(0) var input : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; - ${outputIndicesHelper.o2iImpl} - ${inputIndicesHelper.i2oImpl} + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} - ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} - ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} - calculateInputIndexImpl(inputIndices, outputIndices); - output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}]; - }`; + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + inputIndices = calculateInputIndex(outputIndices); + output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}]; + }`; return { ...tileProgramMetadata, @@ -109,7 +114,11 @@ export const createTileProgramInfo = export const tile = (context: ComputeContext): void => { validateInputs(context.inputs); - const cacheHint = context.inputs.map(x => x.dims.toString()).join('_'); + // const cacheHint = context.inputs[0].dims.toString(); + + const repeats: readonly number[] = getRepeats(context.inputs[1]); + + const cacheHint = context.inputs[0].dims.toString().concat(repeats.toString()); context.compute( {...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)}, {inputs: [0]}); diff --git a/js/web/test/data/ops/tile.jsonc b/js/web/test/data/ops/tile.jsonc new file mode 100644 index 0000000000000..3b1794d1d27b5 --- /dev/null +++ b/js/web/test/data/ops/tile.jsonc @@ -0,0 +1,147 @@ +[ + { + "name": "Tile 2D - float32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[1,2]", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 2, 3, 4, 3, 4], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Tile 2D - int32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[1,2]", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "int32" + }, + { + "data": [1, 2], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 2, 3, 4, 3, 4], + "dims": [2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Tile 2D - uint32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[1,2]", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "uint32" + }, + { + "data": [2, 1], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 1, 2, 3, 4], + "dims": [4, 2], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Tile 1D - float32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [1, 2], + "dims": [2], + "type": "float32" + }, + { + "data": [4], + "dims": [1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 2, 1, 2, 1, 2], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Tile 2D all dims", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [0, 1, 2, 3], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [2, 2], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 70e110d4988db..7dc289ca0968e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1258,8 +1258,8 @@ "test_thresholdedrelu_default", "test_thresholdedrelu_example", "test_thresholdedrelu", - // // "test_tile_precomputed", - // // "test_tile", + "test_tile_precomputed", + "test_tile", // // "test_top_k_negative_axis", // // "test_top_k_smallest", // // "test_top_k", @@ -1360,6 +1360,7 @@ "sqrt.jsonc", "sub.jsonc", "tan.jsonc", + "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc" //"xor.jsonc" diff --git a/onnxruntime/core/providers/js/operators/tile.cc b/onnxruntime/core/providers/js/operators/tile.cc index 6a05e5c9f7f41..eb101c75b01fa 100644 --- a/onnxruntime/core/providers/js/operators/tile.cc +++ b/onnxruntime/core/providers/js/operators/tile.cc @@ -13,7 +13,9 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) .InputMemoryType(OrtMemTypeCPU, 1), Tile); @@ -23,7 +25,9 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) .InputMemoryType(OrtMemTypeCPU, 1), Tile); } // namespace js From 6940691d07e5763f840cb4faf431466eea836ddc Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 11 Aug 2023 11:56:20 -0700 Subject: [PATCH 4/7] Refine --- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 3 --- 1 file changed, 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 9f0fd8bb07a7f..b724f77ffbcbd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -114,10 +114,7 @@ export const createTileProgramInfo = export const tile = (context: ComputeContext): void => { validateInputs(context.inputs); - // const cacheHint = context.inputs[0].dims.toString(); - const repeats: readonly number[] = getRepeats(context.inputs[1]); - const cacheHint = context.inputs[0].dims.toString().concat(repeats.toString()); context.compute( {...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)}, From ee4712dbefe49e4f13f73a3643ebc1a2670e1a58 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Fri, 11 Aug 2023 20:01:26 -0700 Subject: [PATCH 5/7] Merge main and onboard new indices helper --- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 60 +++++++------------------ 1 file changed, 17 insertions(+), 43 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index b724f77ffbcbd..1fd7ad656f790 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -6,20 +6,16 @@ import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {createIndicesHelper, ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper} from './common'; export const tileProgramMetadata = { name: 'Tile', inputTypes: [GpuDataType.default] }; -const getRepeats = (repeatsTensorView: TensorView): readonly number[] => { - const repeats: number[] = []; - if (repeatsTensorView.dims[0] > 0) { - repeatsTensorView.getBigInt64Array().forEach(v => repeats.push(Number(v))); - } - return repeats; -}; +const getRepeats = (repeatsTensorView: TensorView): readonly number[] => + Array.from(repeatsTensorView.getBigInt64Array(), Number); + const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -58,51 +54,29 @@ const getOutputShape = (inputShape: readonly number[], repeats: readonly number[ export const createTileProgramInfo = (tileProgramMetadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { - // We currently only support 4-byte element tensors, so using f32 here is safe - // TODO: support other data types for Tile - const dataType = 'f32'; const inputShape = inputs[0].dims; - const repeats: readonly number[] = getRepeats(inputs[1]); - const outputShape = getOutputShape(inputShape, repeats); const outputSize = ShapeUtil.size(outputShape); - const inputIndicesHelper = createIndicesHelper('input', inputShape); - const outputIndicesHelper = createIndicesHelper('output', outputShape); - - const isl = inputShape.length; - const calculateInputIndexImpl = (): string => ` - fn calculateInputIndex(outputIndices: ${outputIndicesHelper.iType}) -> ${inputIndicesHelper.iType} { - ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} - - for (var i = 0; i < ${isl}; i++) { - // TODO: IndicesHelper should offer uniform way to get/set indices for all ranks - inputIndices${isl >= 2 ? '[i]' : ''} = (outputIndices${isl >= 2 ? '[i]' : ''} % inputShape[i]); - } - - return inputIndices; - }`; + const dataType = inputs[0].dataType; + const input = inputVariable('input', dataType, inputShape); + const output = outputVariable('output', dataType, outputShape); const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - ${calculateInputIndexImpl()}; - @group(0) @binding(0) var input : array<${dataType}>; - @group(0) @binding(1) var output : array<${dataType}>; - - ${outputIndicesHelper.o2iImpl} - ${inputIndicesHelper.i2oImpl} - + ${shaderHelper.declareVariables(input, output)} + ${output.impl('offsetToIndices')} + ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - - ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} - ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} - ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} - inputIndices = calculateInputIndex(outputIndices); - output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}]; - }`; + let outputIndices = ${output.offsetToIndices('global_idx')}; + var inputIndices: ${input.type.indices}; + for (var i = 0; i < ${inputShape.length}; i++) { + ${input.indicesSet('inputIndices', 'i', output.indicesGet('outputIndices', 'i').concat('% inputShape[i]'))} + } + ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} + }`; return { ...tileProgramMetadata, From bed93bcfe3985b373d32908e2eb5a97bfa4ae848 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 17 Aug 2023 18:03:48 -0700 Subject: [PATCH 6/7] More fixes --- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 1fd7ad656f790..2b80ce173245b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -64,7 +64,7 @@ export const createTileProgramInfo = const output = outputVariable('output', dataType, outputShape); const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} ${output.impl('offsetToIndices')} ${input.impl('indicesToOffset', 'get')} @@ -73,7 +73,9 @@ export const createTileProgramInfo = let outputIndices = ${output.offsetToIndices('global_idx')}; var inputIndices: ${input.type.indices}; for (var i = 0; i < ${inputShape.length}; i++) { - ${input.indicesSet('inputIndices', 'i', output.indicesGet('outputIndices', 'i').concat('% inputShape[i]'))} + let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')}; + + ${input.indicesSet('inputIndices', 'i', 'inputDimValue')} } ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; @@ -89,7 +91,7 @@ export const createTileProgramInfo = export const tile = (context: ComputeContext): void => { validateInputs(context.inputs); const repeats: readonly number[] = getRepeats(context.inputs[1]); - const cacheHint = context.inputs[0].dims.toString().concat(repeats.toString()); + const cacheHint = repeats.toString(); context.compute( {...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)}, {inputs: [0]}); From b58e31b195d149086c4a10c5608c124fb0166ff2 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Thu, 17 Aug 2023 18:54:15 -0700 Subject: [PATCH 7/7] Add type template in kernel def --- onnxruntime/core/providers/js/operators/tile.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/js/operators/tile.cc b/onnxruntime/core/providers/js/operators/tile.cc index eb101c75b01fa..f27b6bae0c607 100644 --- a/onnxruntime/core/providers/js/operators/tile.cc +++ b/onnxruntime/core/providers/js/operators/tile.cc @@ -16,6 +16,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .InputMemoryType(OrtMemTypeCPU, 1), Tile); @@ -28,6 +29,7 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .InputMemoryType(OrtMemTypeCPU, 1), Tile); } // namespace js