-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
29 changed files
with
1,081 additions
and
163 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {DataType} from '../../../wasm-common'; | ||
import {TensorView} from '../../tensor'; | ||
import {ShapeUtil} from '../../util'; | ||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; | ||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; | ||
|
||
import {ShaderHelper} from './common'; | ||
|
||
export interface GatherAttributes extends AttributeWithCacheKey { | ||
axis: number; | ||
} | ||
|
||
const validateInputs = (inputs: readonly TensorView[]): void => { | ||
if (!inputs || inputs.length !== 2) { | ||
throw new Error('Gather requires 2 inputs.'); | ||
} | ||
}; | ||
|
||
const createGatherProgramInfo = | ||
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => { | ||
const inputShape = inputs[0].dims; | ||
const indicesShape = inputs[1].dims; | ||
|
||
const inputRank = inputShape.length; | ||
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); | ||
|
||
const outputShape = inputShape.slice(0); | ||
outputShape.splice(axis, 1, ...indicesShape); | ||
|
||
const inputDataType = inputs[0].dataType; | ||
const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); | ||
const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; | ||
const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; | ||
const blockSize = elementSize * block; | ||
const M = ShapeUtil.sizeToDimension(inputShape, axis); | ||
const N = ShapeUtil.size(indicesShape); | ||
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; | ||
const gatheredBatchElements = N * block * elementSize; | ||
const axisDimLimit = inputShape[axis]; | ||
|
||
const inputSize = ShapeUtil.size(inputShape) * elementSize; | ||
const outputSize = ShapeUtil.size(outputShape) * elementSize; | ||
|
||
const totalGathers = M * N; | ||
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits | ||
// That assumption is safe as it's not possible to allocate >2gb buffer for input tensor | ||
// Input data will be treated as u32 or two u32 for 8-byte tensors | ||
const getShaderSource = (shaderHelper: ShaderHelper) => ` | ||
const N: u32 = ${N}; | ||
const elementSize: u32 = ${elementSize}; | ||
const indicesElementSize: u32 = ${indicesElementSize}; | ||
@group(0) @binding(0) var<storage, read> input : array<u32>; | ||
@group(0) @binding(1) var<storage, read> inputIndices : array<i32>; | ||
@group(0) @binding(2) var<storage, read_write> output: array<u32>; | ||
${shaderHelper.mainStart()} | ||
let batch: u32 = global_idx / N; | ||
let i: u32 = global_idx % N; | ||
let srcOffsetBatch: u32 = batch * ${dataBatchElements}; | ||
let dstOffsetBatch: u32 = batch * ${gatheredBatchElements}; | ||
var idx = inputIndices[i * indicesElementSize]; | ||
if (idx < 0) { | ||
idx = idx + ${axisDimLimit}; | ||
} | ||
let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; | ||
let dstOffset = dstOffsetBatch + i * ${blockSize}; | ||
if (srcOffset >= ${inputSize}) { | ||
return; | ||
} | ||
if (dstOffset >= ${outputSize}) { | ||
return; | ||
} | ||
for (var j: u32 = 0; j < ${blockSize}; j++) { | ||
output[dstOffset + j] = input[srcOffset + j]; | ||
} | ||
}`; | ||
return { | ||
...metadata, | ||
outputs: [ | ||
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, | ||
], | ||
getShaderSource, | ||
dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)}) | ||
}; | ||
}; | ||
|
||
export const parseGatherAttributes = (attributes: Record<string, unknown>): GatherAttributes => | ||
createAttributeWithCacheKey({axis: attributes.axis as number}); | ||
|
||
export const gather = (context: ComputeContext, attributes: GatherAttributes): void => { | ||
const inputs = context.inputs; | ||
validateInputs(inputs); | ||
|
||
const metadata = { | ||
name: 'Gather', | ||
inputTypes: [GpuDataType.default, GpuDataType.default], | ||
cacheHint: attributes.cacheKey + inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10), | ||
}; | ||
|
||
context.compute(createGatherProgramInfo(metadata, context.inputs, attributes)); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// This file is an example of an operator test file. | ||
// | ||
// In this file, we demonstrate how to write a test file for ONNX operators. | ||
// There are 2 operator tests defined in this file: | ||
// | ||
// - "Simple Abs test example": a simple operator test for Abs operator. This example shows how to write a simple test with minimal properties. | ||
// | ||
// - "Conv2D with padding": a simple operator test for Conv operator with padding. This example shows how to write a test with all optional properties. | ||
// | ||
|
||
// test file starts with an array of test objects. | ||
[ | ||
// this is the first operator test object (Abs example). | ||
{ | ||
"name": "Simple Abs op test example", // name of the test | ||
"operator": "Abs", // OpType of the operator | ||
"cases": [ | ||
// in this example, we only have one test case. | ||
{ | ||
// name of the test case | ||
"name": "3D float32 test", | ||
"inputs": [ | ||
// specify the input tensor | ||
{ | ||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, -1, -2, -3, -4, -5, -6, -7, -8, 101, 102, 103, 104], | ||
"dims": [2, 3, 4], | ||
"type": "float32" | ||
} | ||
], | ||
"outputs": [ | ||
{ | ||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 1, 2, 3, 4, 5, 6, 7, 8, 101, 102, 103, 104], | ||
"dims": [2, 3, 4], | ||
"type": "float32" | ||
} | ||
] | ||
} | ||
] | ||
}, | ||
// this is the second operator test object (Conv example). | ||
{ | ||
// name of the test | ||
"name": "Conv op test example", | ||
|
||
// OpType of the operator | ||
"operator": "Conv", | ||
|
||
// [optional] specify the attributes of the operator | ||
"attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], | ||
|
||
// [optional] specify a regex pattern to match the platform description. | ||
// | ||
// If not specified, the test will run on all platforms. | ||
// Otherwise, the test will only run on platforms that match the pattern. | ||
"platformCondition": "", | ||
|
||
// [optional] specify input shape definitions. | ||
// | ||
// Sometimes, input shape definitions can offer shape information for ONNX Runtime to optimize its inferencing behavior. | ||
// For example, ORT will transform a NCHW Conv operator into a NHWC operator when the input shape is 4 dimensional. | ||
// If the input shape dimension is unknown, ORT will not perform this optimization. | ||
// | ||
// In operator test, we can specify input shape definitions to test the optimized behavior. | ||
// | ||
// The array of input shape definitions should have the same length as the number of model's inputs. | ||
// | ||
"inputShapeDefinitions": [ | ||
// input 0 shape definition. use semantic names to specify the dynamic dimensions. | ||
["__input_0_dim_0", "__input_0_dim_1", "__input_0_dim_2", "__input_0_dim_3"], | ||
// input 1 shape definition. use numbers to specify the static dimensions. | ||
[1, 1, 2, 2] | ||
], | ||
|
||
// [optional] specify the opset of the operator. | ||
"opset": { "domain": "", "version": 13 }, | ||
|
||
// test cases is required. | ||
"cases": [ | ||
{ | ||
"name": "NCHW Conv2D test", | ||
"inputs": [ | ||
{ | ||
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90], | ||
"dims": [1, 1, 3, 3], | ||
"type": "float32" | ||
}, | ||
{ | ||
"data": [1, 2, 3, 4], | ||
"dims": [1, 1, 2, 2], | ||
"type": "float32" | ||
} | ||
], | ||
"outputs": [ | ||
{ | ||
"data": [370, 470, 670, 770], | ||
"dims": [1, 1, 2, 2], | ||
"type": "float32" | ||
} | ||
] | ||
} | ||
] | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.