From c685ff51b3f0b0f3738dab96310e53324d7c774c Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 15 Feb 2022 15:35:46 +0800 Subject: [PATCH] nits --- tfjs-backend-webgpu/src/flags_webgpu.ts | 2 +- .../src/matmul_packed_webgpu.ts | 28 +++++++------------ 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts index 937761df4eb..62746a9a0e9 100644 --- a/tfjs-backend-webgpu/src/flags_webgpu.ts +++ b/tfjs-backend-webgpu/src/flags_webgpu.ts @@ -31,7 +31,7 @@ ENV.registerFlag('WEBGPU_CPU_FORWARD', () => true); /** * Thread register block size for matmul kernel. */ -ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4); +ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 2); /** * Whether to use conv2d_naive which directly implement the conv2d logic rather diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index dc23d74802e..391365c8838 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -41,8 +41,8 @@ export function makeMatMulPackedSource( let localRow = i32(localId.y); let localCol = i32(localId.x); - let newGlobalRow = i32(workgroupId.y) * ${tileAOuter}; - let newGlobalCol = i32(workgroupId.x) * ${tileBOuter}; + let globalRow = i32(workgroupId.y) * ${tileAOuter}; + let globalCol = i32(workgroupId.x) * ${tileBOuter}; let numTiles = (uniforms.dimInner - 1) / ${tileInner} + 1; @@ -67,7 +67,7 @@ export function makeMatMulPackedSource( for (var inputCol = localCol; inputCol < ${ tileInner}; inputCol = inputCol + ${workGroupSize[0]}) { mm_Asub[inputRow][inputCol] = mm_readA( - newGlobalRow + inputRow, + globalRow + inputRow, t * ${tileInner} + inputCol, globalId); } } @@ -78,7 +78,7 @@ export function makeMatMulPackedSource( tileBOuter}; inputCol = inputCol + ${workGroupSize[0]}) { mm_Bsub[inputRow][inputCol] = mm_readB( t * ${tileInner} + inputRow, - newGlobalCol + inputCol, globalId); + globalCol + inputCol, globalId); } } @@ -95,7 +95,8 @@ export function makeMatMulPackedSource( ACached = mm_Asub[localRow + innerRow * ${workGroupSize[1]}][k]; for (var innerCol = 0; innerCol < ${ workPerThread[0]}; innerCol = innerCol + 1) { - acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; + acc[innerRow][innerCol] = acc[innerRow][innerCol] + + ACached * BCached[innerCol]; } } } @@ -107,13 +108,10 @@ export function makeMatMulPackedSource( workPerThread[1]}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${ workPerThread[0]}; innerCol = innerCol + 1) { - let gRow = newGlobalRow + localRow + innerRow * ${workGroupSize[1]}; - let gCol = newGlobalCol + localCol + innerCol * ${workGroupSize[0]}; - if (gCol < uniforms.dimBOuter && - gRow < uniforms.dimAOuter) { - mm_write(gRow, - gCol, - acc[innerRow][innerCol], globalId); + let gRow = globalRow + localRow + innerRow * ${workGroupSize[1]}; + let gCol = globalCol + localCol + innerCol * ${workGroupSize[0]}; + if (gCol < uniforms.dimBOuter && gRow < uniforms.dimAOuter) { + mm_write(gRow, gCol, acc[innerRow][innerCol], globalId); } } } @@ -195,7 +193,6 @@ export class MatMulPackedProgram implements WebGPUProgram { this.outputShape = outputShape; this.dispatchLayout = {x: [2], y: [1], z: [0]}; const dimInner = transposeA ? aShape[1] : aShape[2]; - workPerThread = 2; this.workGroupSize = computeWorkGroupSizeForMatMul(outputShape[1], dimInner, outputShape[2]); if (outputShape[1] === 1 || outputShape[2] === 1) { @@ -251,11 +248,6 @@ export class MatMulPackedProgram implements WebGPUProgram { if (this.outputShape[1] === 1) { tileInner *= 4; } - util.assert( - tileInner % this.workGroupSize[0] === 0 && - tileInner % this.workGroupSize[1] === 0, - () => `tileInner must be multiple of workgroupsize.x ` + - `and workgroupsize.y`); const tileSizeA = [tileAOuter, tileInner]; const tileSizeB = [tileInner, tileBOuter];