From 0409c639f77926b5966b003950c3668247cf6692 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 15 Oct 2024 06:43:14 +0800 Subject: [PATCH] [js/webgpu] Optimize MultiHeadAttention|Transpose (#22420) ### Description With this optimization, 96 MultiHeadAttention|Transpose ops in phi3 disappear. Phi3 becomes 113 tokens from 107 tokens on my dGPUs. The optimization mainly skips the transpose op if one of the transposed dims is 1. Reshape is enough. --- js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 0949d65174b41..1a31253905694 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = ( if (input.dims.length === 3) { reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); } + if (numHeads === 1 || sequenceLength === 1) { + return reshapedInput; + } return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], @@ -356,6 +359,9 @@ export const maybeTransposeToBNSHAndAddBias = ( biasOffset!, ); reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + if (numHeads === 1 || sequenceLength === 1) { + return reshapedInput; + } return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1],