Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] support customop FastGelu #19392

Merged
merged 5 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Do not modify directly.*
| Erf | ai.onnx(9-12,13+) | |
| Exp | ai.onnx(6-12,13+) | |
| Expand | ai.onnx(8-12,13+) | |
| FastGelu | com.microsoft(1+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
import {fastGelu} from './ops/fast-gelu';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
Expand Down Expand Up @@ -72,6 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Erf', [unaryOps.erf]],
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['FastGelu', [fastGelu]],
['Floor', [unaryOps.floor]],
['FusedConv', [conv, parseConvAttributes]],
['Gather', [gather, parseGatherAttributes]],
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI

${shaderHelper.declareVariables(input, bias, output)}

${erfImpl(`vec4<${dataType}>`, dataType)}
${erfImpl(dataType)}

${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
Expand Down
69 changes: 69 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common';
import * as unary from './unary-op';

// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias.

const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => {
const dataType = inputTensors[0].dataType;
const outputSize = ShapeUtil.size(inputTensors[0].dims);
const biasLength = ShapeUtil.size(inputTensors[1].dims);
satyajandhyala marked this conversation as resolved.
Show resolved Hide resolved
// can only use vec4 when bias length is multiple of 4
const useVec4 = biasLength % 4 === 0;
const getShaderSource = (shaderHelper: ShaderHelper): string => {
const x = inputVariable('x', dataType, [1], 4);
const bias = inputVariable('bias', dataType, [1], 4);
const y = outputVariable('y', dataType, [1], 4);

const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}];

const singleElementBias = (i: 0|1|2|3) => `
let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size;
let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`;
const biasGetExpression = useVec4 ?
`
let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` :
`${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)}
let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`;

return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)}

${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))}

${shaderHelper.mainStart(WORKGROUP_SIZE)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')}

let x = ${x.getByOffset('global_idx')};
${biasGetExpression}
let x_in = x + bias;
${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))}
}`;
};

return {
name: 'FastGeluWithBias',
shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']},
getShaderSource,
getRunData: (inputs) => ({
outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
programUniforms:
[{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}],
dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)}
})
};
};

export const fastGelu = (context: ComputeContext): void => {
if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) {
unary.fastGelu(context);
} else {
context.compute(createFastGeluProgramInfo(context.inputs));
}
};
33 changes: 26 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,23 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

export const erfImpl = (dataType: string, varType = 'f32') => `
export const erfImpl = (varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;

fn erf_vf32(v: ${dataType}) -> ${dataType} {
fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> {
let absv = abs(v);
let x = 1.0 / (1.0 + r0 * absv);
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
}`;

export const erf = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType)));
};

export const exp = (context: ComputeContext): void => {
Expand All @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => {
export const gelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl(`vec4<${dataType}>`, dataType)));
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType)));
};

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
Expand Down Expand Up @@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan'));
};

export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`;

export const tanh = (context: ComputeContext): void => {
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression));
};

export const fastGeluImpl = (varType = 'f32') => `
const fast_gelu_a: ${varType} = 0.5;
const fast_gelu_b: ${varType} = 0.7978845608028654;
const fast_gelu_c: ${varType} = 0.035677408136300125;

fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> {
return ${tanhExpression('v')};
}
`;

export const fastGeluExpression = (x: string) =>
`(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`;

export const fastGelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`));
context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined,
context.inputs[0].dataType));
};

export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
Expand Down
211 changes: 211 additions & 0 deletions js/web/test/data/ops/fast-gelu.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
[
{
"name": "FastGelu test without bias",
"operator": "FastGelu",
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [0.841192],
"dims": [],
"type": "float32"
}
]
},
{
"name": "[2x4]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
"dims": [2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[3x5]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
"dims": [3, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581,
1.0617, 1.17393, 1.28671, 1.39957
],
"dims": [3, 5],
"type": "float32"
}
]
}
]
},
{
"name": "FastGelu test with bias",
"operator": "FastGelu",
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "scalar",
"inputs": [
{
"data": [1],
"dims": [],
"type": "float32"
},
{
"data": [0.5],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [1.39957],
"dims": [],
"type": "float32"
}
]
},
{
"name": "[2x4], [4]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
"dims": [2, 4],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [4],
"type": "float32"
}
],
"outputs": [
{
"data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[2x4], [3]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
"dims": [2, 4],
"type": "float32"
},
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[3x5], [2]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
"dims": [3, 5],
"type": "float32"
},
{
"data": [2, 3],
"dims": [2],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869,
4.39999, 3.49938
],
"dims": [3, 5],
"type": "float32"
}
]
},
{
"name": "[3x5], [7]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
"dims": [3, 5],
"type": "float32"
},
{
"data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7],
"dims": [7],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989,
4.09996, 3.59959
],
"dims": [3, 5],
"type": "float32"
}
]
},
{
"name": "[4x4], [8]",
"inputs": [
{
"data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0],
"dims": [4, 4],
"type": "float32"
},
{
"data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1],
"dims": [8],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957,
1.39957, 4.39999, 1.0617, -0.149419, 3.09737
],
"dims": [4, 4],
"type": "float32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@
"equal.jsonc",
"exp.jsonc",
"expand.jsonc",
"fast-gelu.jsonc",
"floor.jsonc",
"gather-elements.jsonc",
"gemm.jsonc",
Expand Down
Loading
Loading