[js/webgpu] Support GroupQueryAttention
axinging committed Apr 9, 2024
1 parent e19c778 commit cc13e08
Showing 12 changed files with 571 additions and 18 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -53,6 +53,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
+| GroupQueryAttention | com.microsoft(1+) | need to implement mask and past/present |
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
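For context on the new operator: grouped-query attention lets several query heads share one key/value head, so `kvNumHeads` divides `numHeads` and each KV head is reused by `numHeads / kvNumHeads` query heads. A minimal sketch of that head bookkeeping, with hypothetical names that are not part of this commit:

```ts
// Sketch of grouped-query head mapping; names are illustrative only.
interface GqaShape {
  numHeads: number;    // query heads
  kvNumHeads: number;  // key/value heads; must divide numHeads
}

// Each KV head is reused by numHeads / kvNumHeads consecutive query heads.
const kvRepeatFactor = (s: GqaShape): number => {
  if (s.numHeads % s.kvNumHeads !== 0) {
    throw new Error('numHeads must be a multiple of kvNumHeads');
  }
  return s.numHeads / s.kvNumHeads; // 1 = standard MHA; numHeads = MQA
};

// The KV head that query head qHead attends with.
const kvHeadForQueryHead = (qHead: number, s: GqaShape): number =>
    Math.floor(qHead / kvRepeatFactor(s));
```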
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -17,6 +17,7 @@ import {fastGelu} from './ops/fast-gelu';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
+import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention';
import {instanceNorm} from './ops/instance-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
@@ -85,6 +86,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
+  ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
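Each entry in `WEBGPU_OP_RESOLVE_RULES` pairs an operator name with its kernel and an optional attribute parser. A simplified, self-contained sketch of that dispatch pattern, with reduced types standing in for the real `OperatorImplementation` tuple:

```ts
// Reduced model of the resolve-rule registry; types are illustrative only.
type Kernel = (context: unknown, attributes?: unknown) => void;
type AttrParser = (raw: Record<string, unknown>) => unknown;

const rules = new Map<string, [Kernel, AttrParser?]>([
  ['GroupQueryAttention', [(_ctx, _attrs) => {/* launch the GPU program here */}]],
]);

// Dispatch: look up the op, parse its attributes if a parser is registered.
const run = (op: string, ctx: unknown, rawAttrs: Record<string, unknown>) => {
  const rule = rules.get(op);
  if (!rule) {
    throw new Error(`Unsupported op: ${op}`);
  }
  const [kernel, parseAttrs] = rule;
  kernel(ctx, parseAttrs ? parseAttrs(rawAttrs) : undefined);
};

run('GroupQueryAttention', {}, {num_heads: 32, kv_num_heads: 8});
```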
19 changes: 11 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -45,6 +45,7 @@ export interface AttentionParameters {
headSize: number;
vHeadSize: number;
numHeads: number;
+  kvNumHeads?: number;
isUnidirectional: boolean;
pastPresentShareBuffer: boolean;
maskFilterValue: number;
@@ -57,6 +58,7 @@

export interface AttentionAttrs {
numHeads: number;
+  kvNumHeads?: number;
isUnidirectional: number;
maskFilterValue: number;
scale: number;
@@ -319,16 +321,16 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
{inputs: [input], outputs: []});
};

-const computeAttentionProbs =
+export const computeAttentionProbs =
     (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
-     parameters: AttentionParameters, attributes: AttentionAttrs) => {
+     parameters: AttentionParameters, scale: number) => {
const probsShape = [
parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
parameters.kvSequenceLength + parameters.pastSequenceLength
];
// TODO: handle mask

-      const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
+      const alpha = scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : scale;
const components = getMaxComponents(parameters.headSize);
const vectorizedHeadSize = parameters.headSize / components;
const TILE_SIZE = 12;
@@ -417,19 +419,20 @@ const computeAttentionProbs =
return probs;
};

-const computeVxAttentionScore =
-    (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
-      const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
+export const computeVxAttentionScore =
+    (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters, nReps: number = 1) => {
+      const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize * nReps];
const TILE_SIZE = 12;
const dispatch = {
x: Math.ceil(params.vHeadSize / TILE_SIZE),
y: Math.ceil(params.sequenceLength / TILE_SIZE),
z: params.batchSize * params.numHeads
};

const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
-        {type: DataType.uint32, data: params.vHiddenSize}
+        {type: DataType.uint32, data: params.vHiddenSize * nReps}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -500,7 +503,7 @@ export const applyAttention =
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
_past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
-      const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);
+      const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes.scale);

computeVxAttentionScore(context, probs, v, parameters);
};
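Exporting `computeAttentionProbs` and `computeVxAttentionScore`, and threading through a plain `scale` plus an `nReps` repeat factor, lets a GroupQueryAttention kernel reuse both helpers without an `AttentionAttrs` object. A hedged sketch of such a caller (the actual `group-query-attention.ts` added by this commit is not shown in this excerpt, and the import paths are assumptions):

```ts
// Hypothetical caller built on the signatures introduced in this diff;
// import paths are assumed to match the repo layout.
import {TensorView} from '../../tensor-view';
import {ComputeContext} from '../types';
import {AttentionAttrs, AttentionParameters, computeAttentionProbs, computeVxAttentionScore} from './attention';

const applyGroupQueryAttention =
    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, params: AttentionParameters,
     attributes: AttentionAttrs) => {
      // Repeat factor: how many query heads share each KV head (1 = plain MHA).
      const nReps = params.kvNumHeads ? params.numHeads / params.kvNumHeads : 1;
      const probs = computeAttentionProbs(context, q, k, undefined, params, attributes.scale);
      computeVxAttentionScore(context, probs, v, params, nReps);
    };
```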
