Skip to content

Commit

Permalink
[js/webgpu] Support GroupQueryAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Apr 9, 2024
1 parent e19c778 commit e6e2a93
Show file tree
Hide file tree
Showing 12 changed files with 472 additions and 16 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| GroupQueryAttention | com.microsoft(1+) | need implementing mask and past/present |
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {fastGelu} from './ops/fast-gelu';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention';
import {instanceNorm} from './ops/instance-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
Expand Down Expand Up @@ -85,6 +86,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
Expand Down
19 changes: 11 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export interface AttentionParameters {
headSize: number;
vHeadSize: number;
numHeads: number;
kvNumHeads?: number;
isUnidirectional: boolean;
pastPresentShareBuffer: boolean;
maskFilterValue: number;
Expand All @@ -57,6 +58,7 @@ export interface AttentionParameters {

export interface AttentionAttrs {
numHeads: number;
kvNumHeads?: number;
isUnidirectional: number;
maskFilterValue: number;
scale: number;
Expand Down Expand Up @@ -319,16 +321,16 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
{inputs: [input], outputs: []});
};

const computeAttentionProbs =
export const computeAttentionProbs =
(context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
parameters: AttentionParameters, attributes: AttentionAttrs) => {
parameters: AttentionParameters, scale: number) => {
const probsShape = [
parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
parameters.kvSequenceLength + parameters.pastSequenceLength
];
// TODO: handle mask

const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
const alpha = scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : scale;
const components = getMaxComponents(parameters.headSize);
const vectorizedHeadSize = parameters.headSize / components;
const TILE_SIZE = 12;
Expand Down Expand Up @@ -417,19 +419,20 @@ const computeAttentionProbs =
return probs;
};

const computeVxAttentionScore =
(context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
export const computeVxAttentionScore =
(context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters, nReps = 1) => {
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize * nReps];
const TILE_SIZE = 12;
const dispatch = {
x: Math.ceil(params.vHeadSize / TILE_SIZE),
y: Math.ceil(params.sequenceLength / TILE_SIZE),
z: params.batchSize * params.numHeads
};

const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
{type: DataType.uint32, data: params.vHiddenSize}
{type: DataType.uint32, data: params.vHiddenSize * nReps}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down Expand Up @@ -500,7 +503,7 @@ export const applyAttention =
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
_past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);
const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes.scale);

computeVxAttentionScore(context, probs, v, parameters);
};
Expand Down
241 changes: 241 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {TensorView} from '../../tensor-view';
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext} from '../types';

import {AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat, computeAttentionProbs, computeVxAttentionScore} from './attention';
import {maybeTransposeToBNSHAndAddBias} from './multi-head-attentiion';
import {createTileProgramInfo} from './tile';
import {createTransposeProgramInfo, TransposeAttributes} from './transpose';

/**
 * Validates GroupQueryAttention inputs and derives the attention parameters.
 *
 * @param inputs - [query, key, value, past_key, past_value]; past_key/past_value
 *     are not yet supported and cause a throw when present.
 * @param attributes - operator attributes; `numHeads` and `kvNumHeads` are required.
 * @returns the fully-populated AttentionParameters for the compute stage.
 * @throws Error when shapes or attributes violate the constraints below.
 */
export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
  const query = inputs[0];
  const key = inputs[1];
  const value = inputs[2];
  const pastKey = inputs[3];
  const pastValue = inputs[4];

  // Abbreviation and Meanings:
  //   B:   batch_size
  //   S:   sequence_length (input sequence length of query)
  //   P:   past_sequence_length (past sequence length of key or value)
  //   L:   kv_sequence_length (input sequence length of key or value)
  //   M:   max_sequence_length
  //   T:   total_sequence_length = past_sequence_length + kv_sequence_length
  //   N:   num_heads
  //   H:   head size for Q and K, aka q_head_size or k_head_size or qk_head_size
  //   H_v: v_head_size
  //   D_i: input hidden size
  //   D:   hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
  //   D_v: v_hidden_size = num_heads * v_head_size

  // past_key : (B, N, S*, H)
  // past_value : (B, N, S*, H)
  // When no packing for q/k/v:
  //   query (Q) : (B, S, D)
  //   key   (K) : (B, L, D) or (B, N, S*, H)
  //   value (V) : (B, L, D_v) or (B, N, S*, H)
  // When packed kv is used:
  //   query (Q) : (B, S, D)
  //   key   (K) : (B, L, N, 2, H)
  //   value (V) : None
  // When packed qkv is used:
  //   query (Q) : (B, L, N, 3, H) or (B, S, 3*D)
  //   key   (K) : None
  //   value (V) : None

  if (query.dims.length !== 3 && query.dims.length !== 5) {
    throw new Error('Input query is expected to have 3 or 5 dimensions');
  }

  // kv_num_heads is a required attribute for GroupQueryAttention; without this
  // guard, vHeadSize below would silently become NaN.
  if (!attributes.kvNumHeads) {
    throw new Error('Attribute "kv_num_heads" is required and must be positive for GroupQueryAttention');
  }
  const kvNumHeads = attributes.kvNumHeads;

  const dmmhaPacking = false;
  const batchSize = query.dims[0];
  const sequenceLength = query.dims[1];
  const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) :
                                               attributes.numHeads * query.dims[4];
  let kvSequenceLength = sequenceLength;

  let pastSequenceLength = 0;
  let maxSequenceLength = 0;
  const headSize = Math.floor(hiddenSize / attributes.numHeads);
  if (pastKey && pastValue) {
    if (pastKey.dims.length !== 4) {
      throw new Error('Input "past_key" is expected to have 4 dimensions');
    }
    if (pastValue.dims.length !== 4) {
      throw new Error('Input "past_value" is expected to have 4 dimensions');
    }
    pastSequenceLength = pastKey.dims[2];
    maxSequenceLength = pastKey.dims[2];
  } else if (pastKey || pastValue) {
    throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
  }

  let qkvFormat: AttentionQkvFormat;
  if (key) {
    if (query.dims.length !== 3) {
      throw new Error('Input "query" is expected to have 3 dimensions when key is given');
    }
    if (key.dims.length < 3 || key.dims.length > 5) {
      throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
    }
    if (query.dims[0] !== key.dims[0]) {
      throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
    }

    if (key.dims.length === 3) {
      // Plain (B, L, D_kv) key; query hidden size must be a multiple of key's
      // (num_heads is a multiple of kv_num_heads in GQA).
      if (query.dims[2] % key.dims[2] !== 0) {
        throw new Error('Dimension of "query" should be multiple of "key"');
      }
      qkvFormat = AttentionQkvFormat.qkvBSNH;
      kvSequenceLength = key.dims[1];
    } else if (key.dims.length === 5) {
      // Packed KV: (B, L, N, 2, H).
      // NOTE(review): this checks against numHeads; for GQA packed KV the head
      // count is presumably kvNumHeads — confirm when packed KV is implemented.
      if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
      }
      if (value) {
        throw new Error('Expect "value" be none when "key" has packed kv format.');
      }
      qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
      kvSequenceLength = key.dims[1];
    } else {  // key.dims.length === 4: cross-attention with past_key (B, N, L, H)
      if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
        throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
      }

      qkvFormat = AttentionQkvFormat.unknown;
      kvSequenceLength = key.dims[2];
    }
  } else {  // packed QKV
    if (query.dims.length !== 3 && query.dims.length !== 5) {
      throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
    }
    if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
      // Fixed message: this branch validates the packed *qkv* layout.
      throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed qkv');
    }

    qkvFormat = AttentionQkvFormat.qkvBSN3H;
  }

  // Mask is not implemented yet (see op docs).
  const maskType: AttentionMaskType = AttentionMaskType.none;
  let passPastInKv = false;
  let vHiddenSize = hiddenSize;
  if (value) {
    if (value.dims.length !== 3 && value.dims.length !== 4) {
      throw new Error('Input "value" is expected to have 3 or 4 dimensions');
    }

    if (query.dims[0] !== value.dims[0]) {
      throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
    }

    if (value.dims.length === 3) {
      if (kvSequenceLength !== value.dims[1]) {
        throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[2];
    } else {
      if (kvSequenceLength !== value.dims[2]) {
        throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
      }
      vHiddenSize = value.dims[1] * value.dims[3];
      passPastInKv = true;
    }
  }

  const totalSequenceLength = pastSequenceLength + kvSequenceLength;
  const broadcastResPosBias = false;

  // past/present KV cache is not supported yet; reject early so callers get a
  // clear error instead of wrong results.
  if (pastKey) {
    throw new Error('pastKey is not supported');
  }
  if (pastValue) {
    throw new Error('pastValue is not supported');
  }

  return {
    batchSize,
    sequenceLength,
    pastSequenceLength,
    kvSequenceLength,
    totalSequenceLength,
    maxSequenceLength,
    inputHiddenSize: 0,
    hiddenSize,
    vHiddenSize,
    headSize,
    // Guard above guarantees kvNumHeads is a positive number here.
    vHeadSize: Math.floor(vHiddenSize / kvNumHeads),
    numHeads: attributes.numHeads,
    kvNumHeads,
    isUnidirectional: false,
    pastPresentShareBuffer: false,
    maskFilterValue: attributes.maskFilterValue,
    maskType,
    scale: attributes.scale,
    broadcastResPosBias,
    passPastInKv,
    qkvFormat,
  };
};

/**
 * Runs the two-stage attention computation: softmax(Q*K^T * alpha) then probs*V.
 *
 * @param nReps - number of query heads sharing each KV head (numHeads / kvNumHeads).
 */
export const applyAttention =
    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
     _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
     relativePositionBias: TensorView|undefined, parameters: AttentionParameters, nReps: number|undefined) => {
      // Pass the op's scale through: computeAttentionProbs treats 0 as "use the
      // default 1/sqrt(headSize)". The previous hard-coded 1.0 disabled softmax
      // scaling entirely, diverging from the other attention ops in this file set.
      const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, parameters.scale);

      computeVxAttentionScore(context, probs, v, parameters, nReps);
    };

/** Wraps the raw attributes with a cache key so program info can be memoized. */
export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => {
  const copy = {...attributes};
  return createAttributeWithCacheKey(copy);
};

// Permutation that swaps the sequence and head axes: BSNH -> BNSH.
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});

/**
 * Brings a K or V tensor to BNSH layout, first replicating each KV head
 * nReps times so grouped-query heads line up with the query heads.
 *
 * Steps: reshape rank-3 input to (B, S, N, H); tile the head axis by nReps
 * when nReps > 1; finally transpose BSNH -> BNSH.
 */
const maybeExpandAndTransposeToBNSH =
    (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number,
     input: TensorView, nReps: number) => {
      // Ensure a rank-4 (B, S, N, H) view; rank-4 inputs are used as-is.
      let current = input.dims.length === 3 ? input.reshape([batchSize, sequenceLength, numHeads, headSize]) : input;

      if (nReps !== 1) {
        // Replicate each KV head nReps times along the head axis.
        const tiled = context.compute(
            createTileProgramInfo([current], [1, 1, 1, nReps]), {inputs: [current], outputs: [-1]})[0];
        current = tiled.reshape([batchSize, sequenceLength, numHeads * nReps, headSize]);
      }

      // BSNH -> BNSH.
      return context.compute(
          createTransposeProgramInfo(current, weightTransposeAttribute.perm),
          {inputs: [current], outputs: [-1]})[0];
    };

/**
 * GroupQueryAttention entry point: validates inputs, brings Q/K/V to BNSH
 * layout (expanding KV heads to match the query head count), then applies
 * standard attention. Packed QKV/KV and past/present are not implemented yet.
 */
export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
  const params = validateInputs(context.inputs, attributes);

  if (context.inputs[0].dims.length === 5) {
    throw new Error('Packed QKV is not implemented');
  }

  if (context.inputs[1]?.dims.length === 5) {
    throw new Error('Packed KV is not implemented');
  }

  // Q: (B, S, D) -> (B, N, S, H). GQA has no bias input, and validateInputs
  // rejects past_key/past_value, so no bias tensor is passed here (the old
  // code passed inputs[3] = past_key, which is always undefined at this point).
  const Q = maybeTransposeToBNSHAndAddBias(
      context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0],
      undefined, 0);

  // Number of query heads sharing each KV head.
  const nReps = Math.floor(attributes.numHeads / params.kvNumHeads!);

  // K uses the Q/K head size (headSize), not the V head size; the two only
  // coincide when the qk and v hidden sizes are equal.
  const K = maybeExpandAndTransposeToBNSH(
      context, params.batchSize, params.kvNumHeads!, params.kvSequenceLength, params.headSize, context.inputs[1],
      nReps);

  const V = maybeExpandAndTransposeToBNSH(
      context, params.batchSize, params.kvNumHeads!, params.kvSequenceLength, params.vHeadSize, context.inputs[2],
      nReps);
  applyAttention(context, Q, K, V, context.inputs[4], undefined, undefined, undefined, undefined, params, nReps);
};
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ const addBiasTranspose =
{inputs: [qkv, bias], outputs: [-1]})[0];
};

const maybeTransposeToBNSHAndAddBias =
export const maybeTransposeToBNSHAndAddBias =
(context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number,
input: TensorView, bias?: TensorView, biasOffset?: number) => {
// const newDims = [];
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/tile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ const getOutputShape = (inputShape: readonly number[], repeats: readonly number[
return outputShape;
};

export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
export const createTileProgramInfo = (inputs: readonly TensorView[], shape?: number[]): ProgramInfo => {
const inputShape = inputs[0].dims;
const repeats: readonly number[] = getRepeats(inputs[1]);
const repeats: readonly number[] = shape == null ? getRepeats(inputs[1]) : shape;
const outputShape = getOutputShape(inputShape, repeats);
const outputSize = ShapeUtil.size(outputShape);

Expand Down
Loading

0 comments on commit e6e2a93

Please sign in to comment.