From 471fd2707b02d5fd78bd0cf77aa361fe029123d8 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Fri, 18 Aug 2023 10:07:21 -0700 Subject: [PATCH] [JS/WebGPU] Support Tile operator (#17123) ### Description As title ### Motivation and Context Improve WebGPU op coverage --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 98 ++++++++++++ js/web/test/data/ops/tile.jsonc | 147 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 5 +- .../providers/js/js_execution_provider.cc | 4 + .../core/providers/js/operators/tile.cc | 36 +++++ .../core/providers/js/operators/tile.h | 14 ++ 8 files changed, 305 insertions(+), 2 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/tile.ts create mode 100644 js/web/test/data/ops/tile.jsonc create mode 100644 onnxruntime/core/providers/js/operators/tile.cc create mode 100644 onnxruntime/core/providers/js/operators/tile.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index f4f9d014de55d..39b9b035b5169 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -79,5 +79,6 @@ Do not modify directly.* | Tan | ai.onnx(7+) | | | Tanh | ai.onnx(6-12,13+) | | | ThresholdedRelu | ai.onnx(10+) | | +| Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 15a5bdb86c3d0..d7b76e8ddeb85 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -19,6 +19,7 @@ import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; +import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; import {ComputeContext} from './types'; @@ -94,5 +95,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Tan', [unaryOps.tan]], ['Tanh', [unaryOps.tanh]], ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], + ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts new file mode 100644 index 0000000000000..2b80ce173245b --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +export const tileProgramMetadata = { + name: 'Tile', + inputTypes: [GpuDataType.default] +}; + +const getRepeats = (repeatsTensorView: TensorView): readonly number[] => + Array.from(repeatsTensorView.getBigInt64Array(), Number); + + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Tile requires 2 inputs.'); + } + + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 && + inputs[0].dataType !== DataType.uint32) { + throw new Error('Tile only support float, int32, and uint32 data types'); + } + + if (inputs[1].dataType !== DataType.int64) { + throw new Error('Tile `repeats` input should be of int64 data type'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('Tile `repeats` input should be 1-D'); + } + + const repeats: readonly number[] = getRepeats(inputs[1]); + + if (repeats.length !== inputs[0].dims.length) { + throw new Error('Tile `repeats` input should have same number of elements as rank of input data tensor'); + } +}; + +const getOutputShape = (inputShape: readonly number[], repeats: readonly number[]): readonly number[] => { + const outputShape: number[] = []; + + for (let i = 0; i < inputShape.length; ++i) { + outputShape.push(inputShape[i] * repeats[i]); + } + + return outputShape; +}; + +export const createTileProgramInfo = + (tileProgramMetadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const inputShape = inputs[0].dims; + const repeats: readonly number[] = getRepeats(inputs[1]); + const outputShape = getOutputShape(inputShape, repeats); + const outputSize = ShapeUtil.size(outputShape); + + const dataType = inputs[0].dataType; + const input = inputVariable('input', dataType, inputShape); + const output = outputVariable('output', dataType, outputShape); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const inputShape = ${input.indices(...inputShape)}; + ${shaderHelper.declareVariables(input, output)} + ${output.impl('offsetToIndices')} + ${input.impl('indicesToOffset', 'get')} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let outputIndices = ${output.offsetToIndices('global_idx')}; + var inputIndices: ${input.type.indices}; + for (var i = 0; i < ${inputShape.length}; i++) { + let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')}; + + ${input.indicesSet('inputIndices', 'i', 'inputDimValue')} + } + ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} + }`; + + return { + ...tileProgramMetadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const tile = (context: ComputeContext): void => { + validateInputs(context.inputs); + const repeats: readonly number[] = getRepeats(context.inputs[1]); + const cacheHint = repeats.toString(); + context.compute( + {...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)}, + {inputs: [0]}); +}; diff --git a/js/web/test/data/ops/tile.jsonc b/js/web/test/data/ops/tile.jsonc new file mode 100644 index 0000000000000..3b1794d1d27b5 --- /dev/null +++ b/js/web/test/data/ops/tile.jsonc @@ -0,0 +1,147 @@ +[ + { + "name": "Tile 2D - float32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[1,2]", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 2, 3, 4, 3, 4], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Tile 2D - int32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[1,2]", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "int32" + }, + { + "data": [1, 2], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 2, 3, 4, 3, 4], + "dims": [2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Tile 2D - uint32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[1,2]", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "uint32" + }, + { + "data": [2, 1], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 1, 2, 3, 4], + "dims": [4, 2], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Tile 1D - float32", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [1, 2], + "dims": [2], + "type": "float32" + }, + { + "data": [4], + "dims": [1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 1, 2, 1, 2, 1, 2], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Tile 2D all dims", + "operator": "Tile", + "attributes": [], + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [0, 1, 2, 3], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [2, 2], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index da70c08372123..5ee62f9bd1ede 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1258,8 +1258,8 @@ "test_thresholdedrelu_default", "test_thresholdedrelu_example", "test_thresholdedrelu", - // // "test_tile_precomputed", - // // "test_tile", + "test_tile_precomputed", + "test_tile", // // "test_top_k_negative_axis", // // "test_top_k_smallest", // // "test_top_k", @@ -1363,6 +1363,7 @@ "sqrt.jsonc", "sub.jsonc", "tan.jsonc", + "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc" //"xor.jsonc" diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index d0d8d53c4e61f..d2e58fff83198 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -288,6 +288,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); @@ -516,6 +518,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/tile.cc b/onnxruntime/core/providers/js/operators/tile.cc new file mode 100644 index 0000000000000..f27b6bae0c607 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/tile.cc @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "tile.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Tile, + kOnnxDomain, + 6, + 12, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Tile); + +ONNX_OPERATOR_KERNEL_EX( + Tile, + kOnnxDomain, + 13, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Tile); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/tile.h b/onnxruntime/core/providers/js/operators/tile.h new file mode 100644 index 0000000000000..b3eac0ef828cd --- /dev/null +++ b/onnxruntime/core/providers/js/operators/tile.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Tile, Tile); + +} // namespace js +} // namespace onnxruntime