Skip to content

Commit

Permalink
gather
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 6627349 commit 25c9d2a
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 4 deletions.
3 changes: 2 additions & 1 deletion js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {OpSet} from '../../opset';

import * as binaryOps from './ops/binary-op';
import {gather, parseGatherAttributes} from './ops/gather';
import {reshape} from './ops/reshape';
import * as unaryOps from './ops/unary-op';
import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze';
Expand All @@ -28,7 +29,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Flatten', '', '1+', flatten, parseFlattenAttributes],
['Floor', '', '6+', unaryOps.floor],
// ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes],
// ['Gather', '', '1+', gather, parseGatherAttributes],
['Gather', '', '1+', gather, parseGatherAttributes],
// ['Gemm', '', '7-10', gemm, parseGemmAttributesV7],
// ['Gemm', '', '11+', gemm, parseGemmAttributesV11],
// ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes],
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/onnxjs/backends/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export const createIndicesHelper = (name: string, shape: readonly number[]) => {

const i2oImpl = shape.length < 2 ? '' : `
fn ih_i2o_${name}(indices: ptr<function, ${iType}>) -> u32 {
return ${offsets.length > 0 ? offsets.join('+') : '0u'}
return ${offsets.length > 0 ? offsets.join('+') : '0u'};
}`;

const i2oExpression = (varIndices: string) => shape.length < 2 ? varIndices : `ih_i2o_${name}(&${varIndices})`;
Expand Down
130 changes: 130 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {Graph} from '../../../graph';
import {NUMBER_TYPES, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createIndicesHelper, WORKGROUP_SIZE} from './common';

interface GatherAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

export const gather = async(
inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: GatherAttributes): Promise<Tensor[]> => {
validateInputs(inputs, attributes.axis);
return inferenceHandler.run(createGatherProgramInfoLoader(inputs, attributes), inputs);
};

export const parseGatherAttributes: OperatorInitialization<GatherAttributes> = (node: Graph.Node): GatherAttributes =>
createAttributeWithCacheKey({axis: node.attributes.getInt('axis', 0)});

const gatherProgramMetadata = {
name: 'Gather',
inputTypes: [GpuDataType.default, GpuDataType.default]
};

const createGatherProgramInfo =
(metadata: ProgramMetadata, inputs: Tensor[], axis: number, dataType = 'f32'): ProgramInfo => {
const dataShape = inputs[0].dims.slice();
const indicesShape = inputs[1].dims.slice();
const outputShape = new Array(dataShape.length + indicesShape.length - 1);

axis = ShapeUtil.normalizeAxis(axis, dataShape.length);
const indexCopyOps: string[] = [];
if (indicesShape.length > 1) {
indexCopyOps.push('indicesIdx[0] = 0u;');
} else {
indexCopyOps.push('indicesIdx = 0u;');
}
for (let i = 0; i < outputShape.length; i++) {
// outputShape is divided into three parts: A, B, C
// |0 axis| axis + indicesShape.length | end|
// | A | B | C |
//
// dataIdx: [A, inputs[1][B], C]
const outputIdxLValue = outputShape.length > 1 ? `outputIdx[${i}]` : 'outputIdx';
if (i < axis) { // A
const dataIdxLValue = dataShape.length > 1 ? `dataIdx[${i}]` : 'dataIdx';
outputShape[i] = dataShape[i];
indexCopyOps.push(`${dataIdxLValue} = ${outputIdxLValue};`);
} else {
if (i < axis + indicesShape.length) { // B
const indicesIdxLValue = indicesShape.length > 1 ? `indicesIdx[${i - axis}]` : 'indicesIdx';
outputShape[i] = indicesShape[i - axis];
indexCopyOps.push(`${indicesIdxLValue} = ${outputIdxLValue};`);
} else { // C
const dataIdxLValue = dataShape.length > 1 ? `dataIdx[${i - indicesShape.length + 1}]` : 'dataIdx';
outputShape[i] = dataShape[i - indicesShape.length + 1]; // skip 1 for axis
indexCopyOps.push(`${dataIdxLValue} = ${outputIdxLValue};`);
}
}
}
const outputSize = ShapeUtil.size(outputShape);
const outputIndicesHelper = createIndicesHelper('output', outputShape);
const dataIndicesHelper = createIndicesHelper('data', dataShape);
const indicesIndicesHelper = createIndicesHelper('indices', indicesShape);

const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> data : array<${dataType}>;
@group(0) @binding(1) var<storage, read> indices : array<i32>;
@group(0) @binding(2) var<storage, write> output : array<${dataType}>;
${outputIndicesHelper.o2iImpl}
${indicesIndicesHelper.i2oImpl}
${dataIndicesHelper.i2oImpl}
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
if (global_id.x >= ${outputSize}u) {
return;
}
${outputIndicesHelper.indicesVariableDeclaration('outputIdx')}
${outputIndicesHelper.o2iCall('global_id.x', 'outputIdx')}
${dataIndicesHelper.indicesVariableDeclaration('dataIdx')}
${indicesIndicesHelper.indicesVariableDeclaration('indicesIdx')}
${indexCopyOps.join('\n ')}
let idx = indices[${indicesIndicesHelper.i2oExpression('indicesIdx')}];
dataIdx${dataShape.length > 1 ? `[${axis}]` : ''} = u32(select(idx, idx + ${dataShape[axis]}, idx < 0));
output[global_id.x] = data[${dataIndicesHelper.i2oExpression('dataIdx')}];
}`;
return {
...metadata,
outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}],
shaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

const createGatherProgramInfoLoader = (inputs: Tensor[], attributes: GatherAttributes): ProgramInfoLoader => {
const metadata = {...gatherProgramMetadata, cacheHint: attributes.cacheKey};
return {...metadata, get: () => createGatherProgramInfo(metadata, inputs, attributes.axis)};
};

const validateInputs = (inputs: Tensor[], axis: number): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('Gather requires 2 inputs.');
}
const tensorRank = inputs[0].dims.length;
if (tensorRank < 1) {
throw new Error('Invalid input shape.');
}
if (axis < -tensorRank || axis > tensorRank - 1) {
throw new Error('Invalid axis.');
}
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
throw new Error('Invaid input type.');
}
if (inputs[1].type !== 'int32') {
throw new Error('Invaid input type.');
}
};
4 changes: 2 additions & 2 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@
// "test_flatten_axis2",
// "test_flatten_axis3",
// "test_flatten_default_axis",
// "test_gather_0",
// "test_gather_1",
"test_gather_0",
"test_gather_1",
// "test_gemm_nobroadcast",
// "test_gemm_broadcast",
// "test_globalaveragepool_precomputed",
Expand Down

0 comments on commit 25c9d2a

Please sign in to comment.