From 790376fe445b1df49bd09844df74e4ce8c730a86 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Apr 2024 12:24:32 +0800 Subject: [PATCH] Nit --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 3 +-- .../jsep/webgpu/ops/group-query-attention.ts | 18 +++++++++++------- .../core/graph/contrib_ops/bert_defs.cc | 2 +- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index e41c2d8557af5..21d1cf7ce6504 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -544,8 +544,7 @@ export const applyAttention = const outputPresentKey = context.outputCount > 1; const outputPresentValue = context.outputCount > 2; const pastSequenceLength = - parameters - .pastSequenceLength; // (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0; + parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; // Concatinate pastKey and K to produce presentKey. 
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 99dd7a7b7ea43..9be7722560357 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -65,7 +65,9 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent let pastSequenceLength = 0; let maxSequenceLength = 0; const headSize = Math.floor(hiddenSize / attributes.numHeads); - if (pastKey && pastValue) { + const hasPastKey = pastKey && pastKey.dims.length != 0; + const hasPastValue = pastValue && pastValue.dims.length != 0; + if (hasPastKey && hasPastValue) { if (pastKey.dims.length !== 4) { throw new Error('Input "past_key" is expected to have 4 dimensions'); } @@ -82,7 +84,7 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent pastSequenceLength = pastKey.dims[2]; maxSequenceLength = pastKey.dims[2]; } - } else if (pastKey || pastValue) { + } else if (hasPastKey || hasPastValue) { throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); } @@ -301,7 +303,8 @@ export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); const maybeExpandAndTransposeToBNSH = - (context: ComputeContext, input: TensorView, pastKV: TensorView, nReps: number, params: AttentionParameters) => { + (context: ComputeContext, input: TensorView, pastKV: TensorView|undefined, nReps: number, + params: AttentionParameters) => { let reshapedInput = input; let batchSize = params.batchSize; let numHeads = params.kvNumHeads!; @@ -321,7 +324,7 @@ const maybeExpandAndTransposeToBNSH = // Concat here. if (pastKV) { // PastKV is BNSH, transpose to BSNH. 
TODO - //pastKV = context.compute(createTransposeProgramInfo(pastKV, weightTransposeAttribute.perm), {inputs: + // pastKV = context.compute(createTransposeProgramInfo(pastKV, weightTransposeAttribute.perm), {inputs: // [pastKV], outputs: [-1]})[0]; reshapedInput = context.compute( createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), @@ -356,8 +359,9 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti 0); const nReps = Math.floor(attributes.numHeads / params.kvNumHeads!); - const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], context.inputs[3], params.nReps!, params); - - const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], context.inputs[4], nReps, params); + const pastKeyInput = context.inputs[3] && context.inputs[3].dims.length != 0 ? context.inputs[3] : undefined; + const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKeyInput, params.nReps!, params); + const pastValueInput = context.inputs[4] && context.inputs[4].dims.length != 0 ? context.inputs[4] : undefined; + const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValueInput, nReps, params); applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes); }; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 20b2f9ebfe3a5..80462646b93f7 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1099,7 +1099,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... 
otherwise of length past_sequence_length +" "kv_sequence_length.", "T", OpSchema::Optional) - .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { GroupQueryAttentionTypeAndShapeInference(ctx, 3);