Showing 2 changed files with 186 additions and 0 deletions.
fuse-utils.ts
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Attribute} from '../../../attribute';
import {MAX_CLIP, MIN_CLIP} from '../../../util';
import {GlslValueFunction} from '../glsl-definitions';

import {glslClip, glslRelu, glslSigmoid} from './unary-op';

export interface InternalActivationAttributes {
  readonly activation: string;
  readonly clipMin?: number;
  readonly clipMax?: number;
  readonly activationCacheKey: string;
}

// Maps a fused-activation name to its GLSL implementation. Returns the GLSL
// function definition and the one-line statement that applies it to 'value'.
export function getActivationSnippet(attributes: InternalActivationAttributes) {
  let func: GlslValueFunction;
  switch (attributes.activation) {
    case 'Relu':
      func = glslRelu();
      break;
    case 'Sigmoid':
      func = glslSigmoid();
      break;
    case 'Clip':
      func = glslClip(attributes.clipMin!, attributes.clipMax!);
      break;
    // TODO: add other activations that can be fused.
    default:
      return {activationFunction: '', applyActivation: ''};
  }

  const activationName = func.name;
  const activationFunction = func.body;
  const applyActivation = `value = ${activationName}_(value);`;
  return {activationFunction, applyActivation};
}

// Reads the 'activation' attribute from a node; for 'Clip', also reads the
// min/max bounds and folds them into the cache key.
export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => {
  const activation = attributes.getString('activation', '');

  if (activation === 'Clip') {
    const [clipMin, clipMax] = attributes.getFloats('activation_params', [MIN_CLIP, MAX_CLIP]);
    return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`};
  }
  return {activation, activationCacheKey: activation};
};
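
For illustration, a minimal sketch (not part of the commit) of how these helpers compose: the attribute parser produces the activation name, clip bounds, and cache key, and the snippet's two strings are spliced into a generated shader. The `node` variable and the shader skeleton below are hypothetical.

// Hypothetical usage sketch; 'node' stands in for a Graph.Node carrying a fused activation.
const attrs = parseInternalActivationAttributes(node.attributes);
// e.g. for a fused Clip: {activation: 'Clip', clipMin: 0, clipMax: 6, activationCacheKey: 'Clip:0,6'}
const {activationFunction, applyActivation} = getActivationSnippet(attrs);
const shaderSource = `
  ${activationFunction}
  float process(/* indices */) {
    float value = 0.0;
    // ... accumulate the op's result into 'value' ...
    ${applyActivation}  // expands to: value = <activationName>_(value);
    return value;
  }`;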
matmul.ts
@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Graph} from '../../../graph';
import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {BroadcastUtil, ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
// Assumed import: getCoordsDataType and getGlChannels are used below but are not
// part of this diff; the path is a guess based on the webgl backend layout.
import {getCoordsDataType, getGlChannels} from '../utils';

import {getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';

export const matMul: OperatorAsyncImplementation<InternalActivationAttributes> =
    async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: InternalActivationAttributes):
        Promise<Tensor[]> => {
          validateInputs(inputs);

          return inferenceHandler.run(createMatmulProgramInfoLoader(inputs, attributes), inputs);
        };

export const parseMatMulAttributes: OperatorInitialization<InternalActivationAttributes> =
    (node: Graph.Node): InternalActivationAttributes => parseInternalActivationAttributes(node.attributes);

const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
  name: 'MatMul',
  inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] :
                        [GpuDataType.default, GpuDataType.default],
  cacheHint
});

function createMatmulProgramInfo(
    metadata: ProgramMetadata, inputs: Tensor[], activationAttributes: InternalActivationAttributes): ProgramInfo {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;
  const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
  if (!outputShape) {
    throw new Error('Can\'t use matmul on the given tensors');
  }
  const coordsDataType = getCoordsDataType(outputShape.length);
  const allGlChannels = getGlChannels();
  const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes);

  const hasBias = inputs.length > 2;
  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
  const getBiasForMatmulSnippet =
      hasBias ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)}` : '';

  const rank = outputShape.length;
  const arank = aShape.length;
  const brank = bShape.length;
  // The shared dimension: the last axis of A must match the second-to-last axis of B.
  const sharedDim = aShape[aShape.length - 1];
  const shaderSource = `
    ${activationFunction}
    ${getBiasForMatmulSnippet}
    float process(int indices[${rank}]) {
      int a[${arank}];
      int b[${brank}];
      bcastMatmulIndices_A(indices, a);
      bcastMatmulIndices_B(indices, b);
      float value = 0.0;
      for (int k=0; k<${sharedDim}; ++k) {
        a[${arank - 1}] = k;
        b[${brank - 2}] = k;
        value += _A(a) * _B(b);
      }
      ${processBias}
      ${applyActivation}
      return value;
    }`;
  return {
    ...metadata,
    // Assumption: in the webgpu backend the output descriptor uses gpuDataType
    // (the webgl-era textureType field does not exist here).
    output: {dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default},
    shaderSource,
  };
}

export function createMatmulProgramInfoLoader(
    inputs: Tensor[], activationAttributes: InternalActivationAttributes): ProgramInfoLoader {
  const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey);
  return {...metadata, get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes)};
}

const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('MatMul requires 2 inputs.');
  }

  if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) {
    throw new Error('shared dimension does not match.');
  }

  if ((inputs[0].type !== 'float32' && inputs[0].type !== 'float64') ||
      (inputs[1].type !== 'float32' && inputs[1].type !== 'float64')) {
    throw new Error('inputs should be float type');
  }

  if (inputs[0].type !== inputs[1].type) {
    throw new Error('input types should match');
  }
};

// Generates a GLSL snippet that reads the (possibly broadcast) bias input at the
// coordinates of the current output element. The packed variant returns a vec4,
// the unpacked variant a single float.
export function getBiasForMatmul(
    coordsDataType: string, allGlChannels: readonly string[], inShape: readonly number[], outShape: readonly number[],
    isPacked: boolean): string {
  let unpackedCoordsSnippet = '';
  const inRank = inShape.length;
  const outRank = outShape.length;
  const rankDiff = outRank - inRank;
  if (outRank < 2 && inRank > 0) {
    unpackedCoordsSnippet = 'coords';
  } else {
    unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', ');
  }
  // Broadcast dimensions of the bias are pinned to 0 so the same value is reused.
  const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape);
  const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n');
  const inSize = ShapeUtil.size(inShape);
  const isInputScalar = inSize === 1;
  let output = 'vec4(outputValue.xx, outputValue.yy)';
  if (isInputScalar) {
    output = 'vec4(outputValue.x)';
  }
  const getBiasForMatmulSource = isPacked ? `
    vec4 getBiasForMatmul() {
      ${coordsDataType} coords = getOutputCoords();
      ${coordsSnippet}
      vec4 outputValue = getBias(${unpackedCoordsSnippet});
      return ${output};
    }` :
                                            `
    float getBiasForMatmul() {
      ${coordsDataType} coords = getOutputCoords();
      ${coordsSnippet}
      return getBias(coords.x);
    }`;

  return getBiasForMatmulSource;
}
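
As a concrete illustration (again, not part of the commit): for A of shape [2, 3] and B of shape [3, 4] with a fused Relu and no bias, the output shape is [2, 4], rank = 2, and sharedDim = 3, so the template above instantiates to roughly the GLSL below. The relu_ body comes from glslRelu(), which is outside this diff, and its exact name is assumed here.

float process(int indices[2]) {
  int a[2];
  int b[2];
  bcastMatmulIndices_A(indices, a);
  bcastMatmulIndices_B(indices, b);
  float value = 0.0;
  for (int k=0; k<3; ++k) {
    a[1] = k;  // walk the last axis of A
    b[0] = k;  // walk the second-to-last axis of B
    value += _A(a) * _B(b);
  }
  value = relu_(value);  // name assumed from glslRelu()
  return value;
}

Each output element is thus a dot product of length sharedDim over the shared axis, with the fused activation applied before the value is written out.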