fs-eire committed Sep 13, 2022
1 parent e871138 commit a8def8e
Showing 2 changed files with 186 additions and 0 deletions.
48 changes: 48 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/fuse-utils.ts
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Attribute} from '../../../attribute';
import {MAX_CLIP, MIN_CLIP} from '../../../util';
import {GlslValueFunction} from '../glsl-definitions';

import {glslClip, glslRelu, glslSigmoid} from './unary-op';

export interface InternalActivationAttributes {
readonly activation: string;
readonly clipMin?: number;
readonly clipMax?: number;
readonly activationCacheKey: string;
}

export function getActivationSnippet(attributes: InternalActivationAttributes) {
let func: GlslValueFunction;
switch (attributes.activation) {
case 'Relu':
func = glslRelu();
break;
case 'Sigmoid':
func = glslSigmoid();
break;
case 'Clip':
func = glslClip(attributes.clipMin!, attributes.clipMax!);
break;
// TODO: add other activations that can be fused.
default:
return {activationFunction: '', applyActivation: ''};
}

const activationName = func.name;
const activationFunction = func.body;
const applyActivation = `value = ${activationName}_(value);`;
return {activationFunction, applyActivation};
}

export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => {
const activation = attributes.getString('activation', '');

if (activation === 'Clip') {
const [clipMin, clipMax] = attributes.getFloats('activation_params', [MIN_CLIP, MAX_CLIP]);
return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`};
}
return {activation, activationCacheKey: activation};
};
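
For illustration, a minimal sketch of how these helpers compose. The attribute
object below is hypothetical (hand-built rather than parsed from a graph node),
and the exact GLSL function name comes from the GlslValueFunction returned by
glslClip:

import {getActivationSnippet} from './fuse-utils';

// Hypothetical attributes, shaped like what parseInternalActivationAttributes
// would produce for activation='Clip' with activation_params=[0, 6].
const attrs = {activation: 'Clip', clipMin: 0, clipMax: 6, activationCacheKey: 'Clip:0,6'};

const {activationFunction, applyActivation} = getActivationSnippet(attrs);
// activationFunction now holds the GLSL declaration emitted by glslClip(0, 6);
// applyActivation is the call site spliced into the kernel, of the form
// 'value = <name>_(value);'.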
138 changes: 138 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts
@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Graph} from '../../../graph';
import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {BroadcastUtil, ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
// NOTE: TextureType, getCoordsDataType and getGlChannels are used below but were not
// imported in this commit; the paths here are assumptions mirroring the WebGL backend.
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
import {getCoordsDataType, getGlChannels} from '../utils';

import {getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';

export const matMul: OperatorAsyncImplementation<InternalActivationAttributes> =
async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: InternalActivationAttributes):
Promise<Tensor[]> => {
validateInputs(inputs);

return inferenceHandler.run(createMatmulProgramInfoLoader(inputs, attributes), inputs);
};

export const parseMatMulAttributes: OperatorInitialization<InternalActivationAttributes> =
(node: Graph.Node): InternalActivationAttributes => parseInternalActivationAttributes(node.attributes);

const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
name: 'MatMul',
inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] :
[GpuDataType.default, GpuDataType.default],
cacheHint
});

function createMatmulProgramInfo(
metadata: ProgramMetadata, inputs: Tensor[], activationAttributes: InternalActivationAttributes): ProgramInfo {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
if (!outputShape) {
throw new Error('Can\'t use matmul on the given tensors');
}
const coordsDataType = getCoordsDataType(outputShape.length);
const allGlChannels = getGlChannels();
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes);

const hasBias = inputs.length > 2;
const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
const getBiasForMatmulSnippet =
    hasBias ? getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false) : '';

const rank = outputShape.length;
const arank = aShape.length;
const brank = bShape.length;
const sharedDim = aShape[aShape.length - 1];
const shaderSource = `
${activationFunction}
${getBiasForMatmulSnippet}
float process(int indices[${rank}]) {
int a[${arank}];
int b[${brank}];
bcastMatmulIndices_A(indices, a);
bcastMatmulIndices_B(indices, b);
float value = 0.0;
for (int k=0; k<${sharedDim}; ++k) {
a[${arank - 1}] = k;
b[${brank - 2}] = k;
value += _A(a) * _B(b);
}
${processBias}
${applyActivation}
return value;
}`;
return {
...metadata,
output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked},
shaderSource,
};
}
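
For concreteness, this is roughly the shader body the template above emits for
a [2, 3, 4] x [4, 5] matmul with no bias and no fused activation (illustrative;
here rank = 3, arank = 3, brank = 2 and sharedDim = 4, and the bcastMatmulIndices
helpers are assumed to come from the shader preamble):

float process(int indices[3]) {
  int a[3];
  int b[2];
  bcastMatmulIndices_A(indices, a);
  bcastMatmulIndices_B(indices, b);
  float value = 0.0;
  for (int k=0; k<4; ++k) {
    a[2] = k;
    b[0] = k;
    value += _A(a) * _B(b);
  }
  return value;
}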

export function createMatmulProgramInfoLoader(
inputs: Tensor[], activationAttributes: InternalActivationAttributes): ProgramInfoLoader {
const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
return {...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes)};
}
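
A short usage sketch (assumed calling convention, matching the matMul operator
above): the loader exposes the cache metadata up front and defers building the
shader until get() is invoked on a cache miss.

const loader = createMatmulProgramInfoLoader(inputs, attributes);
// name, inputTypes and cacheHint are available immediately for program-cache lookup;
const programInfo = loader.get();  // the shader source is only generated here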

const validateInputs = (inputs: Tensor[]): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('MatMul requires 2 inputs.');
}

if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) {
throw new Error('shared dimension does not match.');
}

if ((inputs[0].type !== 'float32' && inputs[0].type !== 'float64') ||
(inputs[1].type !== 'float32' && inputs[1].type !== 'float64')) {
throw new Error('inputs should be float type');
}

if (inputs[0].type !== inputs[1].type) {
throw new Error('inputs types should match');
}
};
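
For example, under these checks a float32 [2, 3, 4] x [4, 5] pair validates,
while a mismatched inner dimension throws. The sketch below assumes the internal
Tensor constructor accepts (dims, type); the shapes are hypothetical:

import {Tensor} from '../../../tensor';

const a = new Tensor([2, 3, 4], 'float32');
const b = new Tensor([4, 5], 'float32');
validateInputs([a, b]);                              // passes: 4 === 4, types match
validateInputs([a, new Tensor([3, 5], 'float32')]);  // throws: shared dimension does not match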

export function getBiasForMatmul(
coordsDataType: string, allGlChannels: readonly string[], inShape: readonly number[], outShape: readonly number[],
isPacked: boolean): string {
let unpackedCoordsSnippet = '';
const inRank = inShape.length;
const outRank = outShape.length;
const rankDiff = outRank - inRank;
if (outRank < 2 && inRank > 0) {
unpackedCoordsSnippet = 'coords';
} else {
unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', ');
}
const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape);
const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n');
const inSize = ShapeUtil.size(inShape);
const isInputScalar = inSize === 1;
const output = isInputScalar ? 'vec4(outputValue.x)' : 'vec4(outputValue.xx, outputValue.yy)';
const getBiasForMatmulSource = isPacked ? `
vec4 getBiasForMatmul() {
${coordsDataType} coords = getOutputCoords();
${coordsSnippet}
vec4 outputValue = getBias(${unpackedCoordsSnippet});
return ${output};
}` :
`
float getBiasForMatmul() {
${coordsDataType} coords = getOutputCoords();
${coordsSnippet}
return getBias(coords.x);
}`;

return getBiasForMatmulSource;
}
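
As a worked example, take a bias of shape [1, 5] added to a [2, 3, 5] output on
the packed path (the matmul above always passes isPacked = false; the packed
branch is shown here because it exercises the broadcast handling). Assuming
getCoordsDataType(3) returns 'ivec3' and getOutputCoords/getBias come from the
shader preamble, rankDiff = 1 and broadcastDims = [0], so the emitted snippet
would be roughly:

vec4 getBiasForMatmul() {
  ivec3 coords = getOutputCoords();
  coords.y = 0;
  vec4 outputValue = getBias(coords.y, coords.z);
  return vec4(outputValue.xx, outputValue.yy);
}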
