Skip to content

Commit

Permalink
Merge latest commits from main
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Aug 3, 2023
2 parents 3878e79 + a25d0d2 commit a21251e
Show file tree
Hide file tree
Showing 29 changed files with 1,081 additions and 163 deletions.
8 changes: 7 additions & 1 deletion js/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,11 @@
},
"typescript.tsdk": "node_modules/typescript/lib",
"git.detectSubmodules": false,
"cmake.configureOnOpen": false
"cmake.configureOnOpen": false,
"json.schemas": [
{
"fileMatch": ["web/test/data/ops/*.jsonc"],
"url": "./web/test/op-test-schema.json"
}
]
}
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Do not modify directly.*
| Expand | ai.onnx(8-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| Gelu | com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gelu} from './ops/gelu';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {matMul} from './ops/matmul';
Expand Down Expand Up @@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['Gather', [gather, parseGatherAttributes]],
['Gelu', [gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
Expand Down
107 changes: 107 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {ShaderHelper} from './common';

export interface GatherAttributes extends AttributeWithCacheKey {
axis: number;
}

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('Gather requires 2 inputs.');
}
};

const createGatherProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const indicesShape = inputs[1].dims;

const inputRank = inputShape.length;
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);

const outputShape = inputShape.slice(0);
outputShape.splice(axis, 1, ...indicesShape);

const inputDataType = inputs[0].dataType;
const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1);
const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1;
const blockSize = elementSize * block;
const M = ShapeUtil.sizeToDimension(inputShape, axis);
const N = ShapeUtil.size(indicesShape);
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize;
const gatheredBatchElements = N * block * elementSize;
const axisDimLimit = inputShape[axis];

const inputSize = ShapeUtil.size(inputShape) * elementSize;
const outputSize = ShapeUtil.size(outputShape) * elementSize;

const totalGathers = M * N;
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
// That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
// Input data will be treated as u32 or two u32 for 8-byte tensors
const getShaderSource = (shaderHelper: ShaderHelper) => `
const N: u32 = ${N};
const elementSize: u32 = ${elementSize};
const indicesElementSize: u32 = ${indicesElementSize};
@group(0) @binding(0) var<storage, read> input : array<u32>;
@group(0) @binding(1) var<storage, read> inputIndices : array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
${shaderHelper.mainStart()}
let batch: u32 = global_idx / N;
let i: u32 = global_idx % N;
let srcOffsetBatch: u32 = batch * ${dataBatchElements};
let dstOffsetBatch: u32 = batch * ${gatheredBatchElements};
var idx = inputIndices[i * indicesElementSize];
if (idx < 0) {
idx = idx + ${axisDimLimit};
}
let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize};
let dstOffset = dstOffsetBatch + i * ${blockSize};
if (srcOffset >= ${inputSize}) {
return;
}
if (dstOffset >= ${outputSize}) {
return;
}
for (var j: u32 = 0; j < ${blockSize}; j++) {
output[dstOffset + j] = input[srcOffset + j];
}
}`;
return {
...metadata,
outputs: [
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)})
};
};

export const parseGatherAttributes = (attributes: Record<string, unknown>): GatherAttributes =>
createAttributeWithCacheKey({axis: attributes.axis as number});

export const gather = (context: ComputeContext, attributes: GatherAttributes): void => {
const inputs = context.inputs;
validateInputs(inputs);

const metadata = {
name: 'Gather',
inputTypes: [GpuDataType.default, GpuDataType.default],
cacheHint: attributes.cacheKey + inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10),
};

context.compute(createGatherProgramInfo(metadata, context.inputs, attributes));
};
12 changes: 7 additions & 5 deletions js/web/script/test-runner-cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async function main() {
}

const test = testIds && testIds.length > 0 ? allTests[testIds[0]] : undefined;
const condition = test && typeof test !== 'string' ? test.condition : undefined;
const platformCondition = test && typeof test !== 'string' ? test.platformCondition : undefined;

const opsetVersion = folder.split('/')[0];
const category = `node-${opsetVersion}-${backend}`;
Expand All @@ -243,14 +243,16 @@ async function main() {
modelTests = [];
opsetTests.set(category, modelTests);
}
modelTests.push(modelTestFromFolder(path.resolve(TEST_DATA_MODEL_NODE_ROOT, folder), backend, condition, times));
modelTests.push(
modelTestFromFolder(path.resolve(TEST_DATA_MODEL_NODE_ROOT, folder), backend, platformCondition, times));
}

return Array.from(opsetTests.keys()).map(category => ({name: category, tests: opsetTests.get(category)!}));
}

function modelTestFromFolder(
testDataRootFolder: string, backend: string, condition?: Test.Condition, times?: number): Test.ModelTest {
testDataRootFolder: string, backend: string, platformCondition?: Test.PlatformCondition,
times?: number): Test.ModelTest {
if (times === 0) {
npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`);
return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []};
Expand Down Expand Up @@ -326,7 +328,7 @@ async function main() {
npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`);
npmlog.verbose('TestRunnerCli.Init.Model', '===============================================================');

return {name: path.basename(testDataRootFolder), condition, modelUrl, backend, cases};
return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases};
}

function tryLocateModelTestFolder(searchPattern: string): string {
Expand Down Expand Up @@ -385,7 +387,7 @@ async function main() {
// field 'verbose' and 'backend' is not set
for (const test of tests) {
test.backend = backend;
test.opsets = test.opsets || [{domain: '', version: MAX_OPSET_VERSION}];
test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION};
}
npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.');
npmlog.verbose('TestRunnerCli.Init.Op', '===============================================================');
Expand Down
103 changes: 103 additions & 0 deletions js/web/test/data/ops/_example.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// This file is an example of an operator test file.
//
// In this file, we demonstrate how to write a test file for ONNX operators.
// There are 2 operator tests defined in this file:
//
// - "Simple Abs test example": a simple operator test for Abs operator. This example shows how to write a simple test with minimal properties.
//
// - "Conv2D with padding": a simple operator test for Conv operator with padding. This example shows how to write a test with all optional properties.
//

// test file starts with an array of test objects.
[
// this is the first operator test object (Abs example).
{
"name": "Simple Abs op test example", // name of the test
"operator": "Abs", // OpType of the operator
"cases": [
// in this example, we only have one test case.
{
// name of the test case
"name": "3D float32 test",
"inputs": [
// specify the input tensor
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, -1, -2, -3, -4, -5, -6, -7, -8, 101, 102, 103, 104],
"dims": [2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 1, 2, 3, 4, 5, 6, 7, 8, 101, 102, 103, 104],
"dims": [2, 3, 4],
"type": "float32"
}
]
}
]
},
// this is the second operator test object (Conv example).
{
// name of the test
"name": "Conv op test example",

// OpType of the operator
"operator": "Conv",

// [optional] specify the attributes of the operator
"attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],

// [optional] specify a regex pattern to match the platform description.
//
// If not specified, the test will run on all platforms.
// Otherwise, the test will only run on platforms that match the pattern.
"platformCondition": "",

// [optional] specify input shape definitions.
//
// Sometimes, input shape definitions can offer shape information for ONNX Runtime to optimize its inferencing behavior.
// For example, ORT will transform a NCHW Conv operator into a NHWC operator when the input shape is 4 dimensional.
// If the input shape dimension is unknown, ORT will not perform this optimization.
//
// In operator test, we can specify input shape definitions to test the optimized behavior.
//
// The array of input shape definitions should have the same length as the number of model's inputs.
//
"inputShapeDefinitions": [
// input 0 shape definition. use semantic names to specify the dynamic dimensions.
["__input_0_dim_0", "__input_0_dim_1", "__input_0_dim_2", "__input_0_dim_3"],
// input 1 shape definition. use numbers to specify the static dimensions.
[1, 1, 2, 2]
],

// [optional] specify the opset of the operator.
"opset": { "domain": "", "version": 13 },

// test cases is required.
"cases": [
{
"name": "NCHW Conv2D test",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [370, 470, 670, 770],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
}
]
}
]
6 changes: 3 additions & 3 deletions js/web/test/data/ops/gelu.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
"name": "gelu",
"operator": "Gelu",
"opsets": [{ "domain": "com.microsoft", "version": 1 }],
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [],
"cases": [
{
Expand All @@ -16,7 +16,7 @@
],
"outputs": [
{
"data": [1.0, 0, 0, 2.0],
"data": [0.8413447141647339, -0.04550027847290039, 0, 1.9544997215270996],
"dims": [2, 2],
"type": "float32"
}
Expand All @@ -33,7 +33,7 @@
],
"outputs": [
{
"data": [1.0],
"data": [0.8413447141647339],
"dims": [],
"type": "float32"
}
Expand Down
2 changes: 1 addition & 1 deletion js/web/test/data/ops/pad-big.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
"name": "constant 2D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "reflect", "type": "string" },
{ "name": "pads", "data": [0, 0, 1, 1, 0, 0, 1, 1], "type": "ints" }
Expand Down
14 changes: 7 additions & 7 deletions js/web/test/data/ops/pad.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
"name": "constant 2D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "constant", "type": "string" },
{ "name": "value", "data": 1.2, "type": "float" },
Expand Down Expand Up @@ -35,7 +35,7 @@
{
"name": "constant 3D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "constant", "type": "string" },
{ "name": "value", "data": 2.3, "type": "float" },
Expand Down Expand Up @@ -79,7 +79,7 @@
{
"name": "Reflect 1D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "reflect", "type": "string" },
{ "name": "pads", "data": [5, 7], "type": "ints" }
Expand Down Expand Up @@ -107,7 +107,7 @@
{
"name": "Reflect 2D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "reflect", "type": "string" },
{ "name": "pads", "data": [3, 2, 2, 5], "type": "ints" }
Expand Down Expand Up @@ -139,7 +139,7 @@
{
"name": "Reflect 3D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "reflect", "type": "string" },
{ "name": "pads", "data": [1, 2, 2, 2, 3, 1], "type": "ints" }
Expand Down Expand Up @@ -182,7 +182,7 @@
{
"name": "Edge 2D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "edge", "type": "string" },
{ "name": "pads", "data": [3, 2, 2, 3], "type": "ints" }
Expand Down Expand Up @@ -214,7 +214,7 @@
{
"name": "Edge 3D",
"operator": "Pad",
"opsets": [{ "domain": "", "version": 10 }],
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "edge", "type": "string" },
{ "name": "pads", "data": [1, 2, 2, 2, 3, 1], "type": "ints" }
Expand Down
Loading

0 comments on commit a21251e

Please sign in to comment.