Skip to content

Commit

Permalink
[JS/WebGPU] Support Tile operator (microsoft#17123)
Browse files Browse the repository at this point in the history
### Description
As title

### Motivation and Context
Improve WebGPU op coverage
  • Loading branch information
hariharans29 authored and kleiti committed Mar 22, 2024
1 parent e43de82 commit 471fd27
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 2 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ Do not modify directly.*
| Tan | ai.onnx(7+) | |
| Tanh | ai.onnx(6-12,13+) | |
| ThresholdedRelu | ai.onnx(10+) | |
| Tile | ai.onnx(6-12,13+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {ComputeContext} from './types';
Expand Down Expand Up @@ -94,5 +95,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Tan', [unaryOps.tan]],
['Tanh', [unaryOps.tanh]],
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
['Tile', [tile]],
['Transpose', [transpose, parseTransposeAttributes]],
]);
98 changes: 98 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/tile.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';

export const tileProgramMetadata = {
name: 'Tile',
inputTypes: [GpuDataType.default]
};

const getRepeats = (repeatsTensorView: TensorView): readonly number[] =>
Array.from(repeatsTensorView.getBigInt64Array(), Number);


const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('Tile requires 2 inputs.');
}

if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 &&
inputs[0].dataType !== DataType.uint32) {
throw new Error('Tile only support float, int32, and uint32 data types');
}

if (inputs[1].dataType !== DataType.int64) {
throw new Error('Tile `repeats` input should be of int64 data type');
}

if (inputs[1].dims.length !== 1) {
throw new Error('Tile `repeats` input should be 1-D');
}

const repeats: readonly number[] = getRepeats(inputs[1]);

if (repeats.length !== inputs[0].dims.length) {
throw new Error('Tile `repeats` input should have same number of elements as rank of input data tensor');
}
};

const getOutputShape = (inputShape: readonly number[], repeats: readonly number[]): readonly number[] => {
const outputShape: number[] = [];

for (let i = 0; i < inputShape.length; ++i) {
outputShape.push(inputShape[i] * repeats[i]);
}

return outputShape;
};

export const createTileProgramInfo =
(tileProgramMetadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
const inputShape = inputs[0].dims;
const repeats: readonly number[] = getRepeats(inputs[1]);
const outputShape = getOutputShape(inputShape, repeats);
const outputSize = ShapeUtil.size(outputShape);

const dataType = inputs[0].dataType;
const input = inputVariable('input', dataType, inputShape);
const output = outputVariable('output', dataType, outputShape);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
${shaderHelper.declareVariables(input, output)}
${output.impl('offsetToIndices')}
${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};
var inputIndices: ${input.type.indices};
for (var i = 0; i < ${inputShape.length}; i++) {
let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')};
${input.indicesSet('inputIndices', 'i', 'inputDimValue')}
}
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
}`;

return {
...tileProgramMetadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

export const tile = (context: ComputeContext): void => {
validateInputs(context.inputs);
const repeats: readonly number[] = getRepeats(context.inputs[1]);
const cacheHint = repeats.toString();
context.compute(
{...tileProgramMetadata, cacheHint, get: () => createTileProgramInfo(tileProgramMetadata, context.inputs)},
{inputs: [0]});
};
147 changes: 147 additions & 0 deletions js/web/test/data/ops/tile.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
[
{
"name": "Tile 2D - float32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[1,2]",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [1, 2],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 1, 2, 3, 4, 3, 4],
"dims": [2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "Tile 2D - int32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[1,2]",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "int32"
},
{
"data": [1, 2],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 1, 2, 3, 4, 3, 4],
"dims": [2, 4],
"type": "int32"
}
]
}
]
},
{
"name": "Tile 2D - uint32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[1,2]",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "uint32"
},
{
"data": [2, 1],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 1, 2, 3, 4],
"dims": [4, 2],
"type": "uint32"
}
]
}
]
},
{
"name": "Tile 1D - float32",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [1, 2],
"dims": [2],
"type": "float32"
},
{
"data": [4],
"dims": [1],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 2, 1, 2, 1, 2, 1, 2],
"dims": [8],
"type": "float32"
}
]
}
]
},
{
"name": "Tile 2D all dims",
"operator": "Tile",
"attributes": [],
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [0, 1, 2, 3],
"dims": [2, 2],
"type": "float32"
},
{
"data": [2, 2],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3],
"dims": [4, 4],
"type": "float32"
}
]
}
]
}
]
5 changes: 3 additions & 2 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1258,8 +1258,8 @@
"test_thresholdedrelu_default",
"test_thresholdedrelu_example",
"test_thresholdedrelu",
// // "test_tile_precomputed",
// // "test_tile",
"test_tile_precomputed",
"test_tile",
// // "test_top_k_negative_axis",
// // "test_top_k_smallest",
// // "test_top_k",
Expand Down Expand Up @@ -1363,6 +1363,7 @@
"sqrt.jsonc",
"sub.jsonc",
"tan.jsonc",
"tile.jsonc",
"transpose.jsonc",
"transpose_int32_uint32.jsonc"
//"xor.jsonc"
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
Expand Down Expand Up @@ -516,6 +518,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tile)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
Expand Down
36 changes: 36 additions & 0 deletions onnxruntime/core/providers/js/operators/tile.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/js/js_kernel.h"
#include "tile.h"

namespace onnxruntime {
namespace js {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Tile,
kOnnxDomain,
6,
12,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()})
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPU, 1),
Tile);

ONNX_OPERATOR_KERNEL_EX(
Tile,
kOnnxDomain,
13,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()})
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPU, 1),
Tile);
} // namespace js
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/js/operators/tile.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/js/js_kernel.h"

namespace onnxruntime {
namespace js {

JSEP_KERNEL_IMPL(Tile, Tile);

} // namespace js
} // namespace onnxruntime

0 comments on commit 471fd27

Please sign in to comment.