From 46d3ee6abaacda363ffbc836617a78b4c7606ef4 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 1 Nov 2024 23:04:42 +0800 Subject: [PATCH] [js/webgpu] Optimize MatMul with M = 1 (#22577) ### Description BUG #22031 In the demucs model, there are lots of MatMul ops with shapes like below: `input[0]: [3448,1,512] | float32, input[1]: [512,1536] | float32, output[0]: [3448,1,1536] | float32` We can see that for this kind of shape, the batch size is a big value, but M = 1. Our current algorithm is based on [M, N] to partition tiles, which is not efficient for such kind of shapes. This PR reshapes the inputs to improve the matmul performance. Before: [3448,1,512] x [512,1536] = [3448,1,1536] After: [1, 3448, 512] x [512, 1536] = [1, 3448, 1536] , then the output can be reshaped to [3448, 1, 1536] The overall MatMul time in demucs model becomes 1778.45 ms from 4418.17 ms on my iGPUs. --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 15 ++++++- js/web/test/data/ops/matmul.jsonc | 50 +++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 7605e67c972b9..a645163d6dfa6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -201,6 +201,19 @@ export const matMul = (context: ComputeContext): void => { if (N < 8 && K < 8) { context.compute(createNaiveMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); + const M = outputShape[outputShape.length - 2]; + const batchA = ShapeUtil.size(context.inputs[0].dims.slice(0, -2)); + const batchB = ShapeUtil.size(context.inputs[1].dims.slice(0, -2)); + if (batchA !== 1 && M === 1 && batchB === 1) { + const reshapedA = context.inputs[0].reshape([1, batchA, K]); + const reshapedB = context.inputs[1].reshape([1, K, N]); + const matmulOutputShape = [1, batchA, N]; + const matmulInputs = [reshapedA, reshapedB]; + context.compute(createMatmulProgramInfo(matmulInputs, { activation: '' }, outputShape, matmulOutputShape), { + inputs: matmulInputs, + }); + } else { + context.compute(createMatmulProgramInfo(context.inputs, { activation: '' }, outputShape)); + } } }; diff --git a/js/web/test/data/ops/matmul.jsonc b/js/web/test/data/ops/matmul.jsonc index 2c2cf509d7e3e..ead6427350bca 100644 --- a/js/web/test/data/ops/matmul.jsonc +++ b/js/web/test/data/ops/matmul.jsonc @@ -95,6 +95,56 @@ } ] }, + { + "name": "multiplies 3D tensors with M = 1", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [6, 1, 8], + "type": "float32" + }, + { + "data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [1, 8, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550], + "dims": [6, 1, 3], + "type": "float32" + } + ] + }, + { + "name": "multiplies 4D tensors with M = 1", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 17, 18, 19, 20, 21, 22, 23, 24, 9, 10, 11, 12, 13, 14, 15, 16, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [2, 3, 1, 8], + "type": "float32" + }, + { + "data": [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [1, 1, 8, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [478, 514, 550, 2270, 2434, 2598, 1374, 1474, 1574, 590, 634, 678, 1486, 1594, 1702, 478, 514, 550], + "dims": [2, 3, 1, 3], + "type": "float32" + } + ] + }, { "name": "multiplies 4D tensors", "inputs": [