Skip to content

Commit

Permalink
Nit
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Apr 26, 2024
1 parent c28b342 commit 790376f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
3 changes: 1 addition & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,7 @@ export const applyAttention =
const outputPresentKey = context.outputCount > 1;
const outputPresentValue = context.outputCount > 2;
const pastSequenceLength =
parameters
.pastSequenceLength; // (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
// Concatenate pastKey and K to produce presentKey.
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
Expand Down
18 changes: 11 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
let pastSequenceLength = 0;
let maxSequenceLength = 0;
const headSize = Math.floor(hiddenSize / attributes.numHeads);
if (pastKey && pastValue) {
const hasPastKey = pastKey && pastKey.dims.length != 0;
const hasPastValue = pastValue && pastValue.dims.length != 0;
if (hasPastKey && hasPastValue) {
if (pastKey.dims.length !== 4) {
throw new Error('Input "past_key" is expected to have 4 dimensions');
}
Expand All @@ -82,7 +84,7 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
pastSequenceLength = pastKey.dims[2];
maxSequenceLength = pastKey.dims[2];
}
} else if (pastKey || pastValue) {
} else if (hasPastKey || hasPastValue) {
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
}

Expand Down Expand Up @@ -301,7 +303,8 @@ export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs):
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});

const maybeExpandAndTransposeToBNSH =
(context: ComputeContext, input: TensorView, pastKV: TensorView, nReps: number, params: AttentionParameters) => {
(context: ComputeContext, input: TensorView, pastKV: TensorView|undefined, nReps: number,
params: AttentionParameters) => {
let reshapedInput = input;
let batchSize = params.batchSize;
let numHeads = params.kvNumHeads!;
Expand All @@ -321,7 +324,7 @@ const maybeExpandAndTransposeToBNSH =
// Concat here.
if (pastKV) {
// PastKV is BNSH, transpose to BSNH. TODO
//pastKV = context.compute(createTransposeProgramInfo(pastKV, weightTransposeAttribute.perm), {inputs:
// pastKV = context.compute(createTransposeProgramInfo(pastKV, weightTransposeAttribute.perm), {inputs:
// [pastKV], outputs: [-1]})[0];
reshapedInput = context.compute(
createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params),
Expand Down Expand Up @@ -356,8 +359,9 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti
0);

const nReps = Math.floor(attributes.numHeads / params.kvNumHeads!);
const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], context.inputs[3], params.nReps!, params);

const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], context.inputs[4], nReps, params);
const pastKeyInput = context.inputs[3] && context.inputs[3].dims.length != 0 ? context.inputs[3] : undefined;
const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKeyInput, params.nReps!, params);
const pastValueInput = context.inputs[4] && context.inputs[2].dims.length != 0 ? context.inputs[4] : undefined;
const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValueInput, nReps, params);
applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
};
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
"T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
GroupQueryAttentionTypeAndShapeInference(ctx, 3);
Expand Down

0 comments on commit 790376f

Please sign in to comment.