Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Support Tile operator #17123

Merged
merged 12 commits into from
Aug 18, 2023
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,6 @@ Do not modify directly.*
| Tan | ai.onnx(7+) | |
| Tanh | ai.onnx(6-12,13+) | |
| ThresholdedRelu | ai.onnx(10+) | |
| Tile | ai.onnx(6-12,13+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {parseResizeAttributes, resize} from './ops/resize';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {ComputeContext} from './types';
Expand Down Expand Up @@ -91,5 +92,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Tan', [unaryOps.tan]],
['Tanh', [unaryOps.tanh]],
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
['Tile', [tile]],
['Transpose', [transpose, parseTransposeAttributes]],
]);
96 changes: 96 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/tile.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';

export const tileProgramMetadata = {
name: 'Tile',
inputTypes: [GpuDataType.default]
};

const getRepeats = (repeatsTensorView: TensorView): readonly number[] =>
Array.from(repeatsTensorView.getBigInt64Array(), Number);


const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('Tile requires 2 inputs.');
}

if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 &&
guschmue marked this conversation as resolved.
Show resolved Hide resolved
inputs[0].dataType !== DataType.uint32) {
throw new Error('Tile only support float, int32, and uint32 data types');
}

if (inputs[1].dataType !== DataType.int64) {
throw new Error('Tile `repeats` input should be of int64 data type');
}

if (inputs[1].dims.length !== 1) {
throw new Error('Tile `repeats` input should be 1-D');
}

const repeats: readonly number[] = getRepeats(inputs[1]);

if (repeats.length !== inputs[0].dims.length) {
throw new Error('Tile `repeats` input should have same number of elements as rank of input data tensor');
}
};

const getOutputShape = (inputShape: readonly number[], repeats: readonly number[]): readonly number[] => {
const outputShape: number[] = [];

for (let i = 0; i < inputShape.length; ++i) {
outputShape.push(inputShape[i] * repeats[i]);
}

return outputShape;
};

export const createTileProgramInfo =
(tileProgramMetadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
const inputShape = inputs[0].dims;
const repeats: readonly number[] = getRepeats(inputs[1]);
const outputShape = getOutputShape(inputShape, repeats);
const outputSize = ShapeUtil.size(outputShape);

const dataType = inputs[0].dataType;
const input = inputVariable('input', dataType, inputShape);
const output = outputVariable('output', dataType, outputShape);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
${shaderHelper.declareVariables(input, output)}
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};
var inputIndices: ${input.type.indices};
for (var i = 0; i < ${inputShape.length}; i++) {
${input.indicesSet('inputIndices', 'i', output.indicesGet('outputIndices', 'i').concat('% inputShape[i]'))}
}
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
}`;

return {
...tileProgramMetadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

export const tile = (context: ComputeContext): void => {
validateInputs(context.inputs);
const repeats: readonly number[] = getRepeats(context.inputs[1]);
const cacheHint = context.inputs[0].dims.toString().concat(repeats.toString());
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
context.compute(
{...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)},
{inputs: [0]});
};
147 changes: 147 additions & 0 deletions js/web/test/data/ops/tile.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
[
{
"name": "Tile 2D - float32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[1,2]",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [1, 2],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 1, 2, 3, 4, 3, 4],
"dims": [2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "Tile 2D - int32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[1,2]",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "int32"
},
{
"data": [1, 2],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 1, 2, 3, 4, 3, 4],
"dims": [2, 4],
"type": "int32"
}
]
}
]
},
{
"name": "Tile 2D - uint32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[1,2]",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "uint32"
},
{
"data": [2, 1],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 1, 2, 3, 4],
"dims": [4, 2],
"type": "uint32"
}
]
}
]
},
{
"name": "Tile 1D - float32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [1, 2],
"dims": [2],
"type": "float32"
},
{
"data": [4],
"dims": [1],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 1, 2, 1, 2, 1, 2],
"dims": [8],
"type": "float32"
}
]
}
]
},
{
"name": "Tile 2D all dims",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [0, 1, 2, 3],
"dims": [2, 2],
"type": "float32"
},
{
"data": [2, 2],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3],
"dims": [4, 4],
"type": "float32"
}
]
}
]
}
]
5 changes: 3 additions & 2 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1258,8 +1258,8 @@
"test_thresholdedrelu_default",
"test_thresholdedrelu_example",
"test_thresholdedrelu",
// // "test_tile_precomputed",
// // "test_tile",
"test_tile_precomputed",
"test_tile",
// // "test_top_k_negative_axis",
// // "test_top_k_smallest",
// // "test_top_k",
Expand Down Expand Up @@ -1360,6 +1360,7 @@
"sqrt.jsonc",
"sub.jsonc",
"tan.jsonc",
"tile.jsonc",
"transpose.jsonc",
"transpose_int32_uint32.jsonc"
//"xor.jsonc"
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
Expand Down Expand Up @@ -512,6 +514,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/js/operators/tile.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/js/js_kernel.h"
#include "tile.h"

namespace onnxruntime {
namespace js {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Tile,
kOnnxDomain,
6,
12,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()})
.InputMemoryType(OrtMemTypeCPU, 1),
Tile);

ONNX_OPERATOR_KERNEL_EX(
Tile,
kOnnxDomain,
13,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()})
.InputMemoryType(OrtMemTypeCPU, 1),
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
Tile);
} // namespace js
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/js/operators/tile.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/js/js_kernel.h"

namespace onnxruntime {
namespace js {

JSEP_KERNEL_IMPL(Tile, Tile);

} // namespace js
} // namespace onnxruntime