
nits
qjia7 committed Feb 15, 2022
1 parent c243e11 commit c685ff5
Showing 2 changed files with 11 additions and 19 deletions.
tfjs-backend-webgpu/src/flags_webgpu.ts (2 changes: 1 addition & 1 deletion)
@@ -31,7 +31,7 @@ ENV.registerFlag('WEBGPU_CPU_FORWARD', () => true);
 /**
  * Thread register block size for matmul kernel.
  */
-ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4);
+ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 2);

 /**
  * Whether to use conv2d_naive which directly implement the conv2d logic rather
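
The flag above is consumed elsewhere in the WebGPU backend through tf.env(); that call site is not part of this diff, so the snippet below is only a sketch of how the new default of 2 would be picked up, and how it could be overridden.

    import {env} from '@tensorflow/tfjs-core';

    // Read the register block size for the matmul kernel; with this commit the
    // registered default is 2 unless the flag has been overridden.
    const workPerThread = env().getNumber('WEBGPU_MATMUL_WORK_PER_THREAD');

    // Hypothetical tuning step: override the default before running matmuls.
    env().set('WEBGPU_MATMUL_WORK_PER_THREAD', 4);
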
tfjs-backend-webgpu/src/matmul_packed_webgpu.ts (28 changes: 10 additions & 18 deletions)
@@ -41,8 +41,8 @@ export function makeMatMulPackedSource(
 let localRow = i32(localId.y);
 let localCol = i32(localId.x);
-let newGlobalRow = i32(workgroupId.y) * ${tileAOuter};
-let newGlobalCol = i32(workgroupId.x) * ${tileBOuter};
+let globalRow = i32(workgroupId.y) * ${tileAOuter};
+let globalCol = i32(workgroupId.x) * ${tileBOuter};
 let numTiles = (uniforms.dimInner - 1) / ${tileInner} + 1;
@@ -67,7 +67,7 @@ export function makeMatMulPackedSource(
 for (var inputCol = localCol; inputCol < ${
 tileInner}; inputCol = inputCol + ${workGroupSize[0]}) {
 mm_Asub[inputRow][inputCol] = mm_readA(
-newGlobalRow + inputRow,
+globalRow + inputRow,
 t * ${tileInner} + inputCol, globalId);
 }
 }
@@ -78,7 +78,7 @@ export function makeMatMulPackedSource(
 tileBOuter}; inputCol = inputCol + ${workGroupSize[0]}) {
 mm_Bsub[inputRow][inputCol] = mm_readB(
 t * ${tileInner} + inputRow,
-newGlobalCol + inputCol, globalId);
+globalCol + inputCol, globalId);
 }
 }
@@ -95,7 +95,8 @@ export function makeMatMulPackedSource(
 ACached = mm_Asub[localRow + innerRow * ${workGroupSize[1]}][k];
 for (var innerCol = 0; innerCol < ${
 workPerThread[0]}; innerCol = innerCol + 1) {
-acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];
+acc[innerRow][innerCol] = acc[innerRow][innerCol] +
+ACached * BCached[innerCol];
 }
 }
 }
@@ -107,13 +108,10 @@ export function makeMatMulPackedSource(
 workPerThread[1]}; innerRow = innerRow + 1) {
 for (var innerCol = 0; innerCol < ${
 workPerThread[0]}; innerCol = innerCol + 1) {
-let gRow = newGlobalRow + localRow + innerRow * ${workGroupSize[1]};
-let gCol = newGlobalCol + localCol + innerCol * ${workGroupSize[0]};
-if (gCol < uniforms.dimBOuter &&
-gRow < uniforms.dimAOuter) {
-mm_write(gRow,
-gCol,
-acc[innerRow][innerCol], globalId);
+let gRow = globalRow + localRow + innerRow * ${workGroupSize[1]};
+let gCol = globalCol + localCol + innerCol * ${workGroupSize[0]};
+if (gCol < uniforms.dimBOuter && gRow < uniforms.dimAOuter) {
+mm_write(gRow, gCol, acc[innerRow][innerCol], globalId);
 }
 }
 }
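
In the shader source above, globalRow and globalCol are the top-left output coordinates of the tile owned by the current workgroup, and each invocation writes a workPerThread[1] x workPerThread[0] block of that tile, striding by the workgroup size. A rough TypeScript sketch of the same index arithmetic follows; the tile-size derivation is assumed here, since it lies outside the hunks shown.

    // Assumed derivation: a workgroup covers tileAOuter rows x tileBOuter cols.
    const workGroupSize = [8, 8];    // [x, y]
    const workPerThread = [2, 2];    // [cols per thread, rows per thread]
    const tileAOuter = workGroupSize[1] * workPerThread[1];
    const tileBOuter = workGroupSize[0] * workPerThread[0];

    // Output element written by workgroup (wx, wy), invocation (lx, ly),
    // for a given (innerRow, innerCol); mirrors gRow/gCol in the WGSL above.
    function outputIndex(wx: number, wy: number, lx: number, ly: number,
                         innerRow: number, innerCol: number) {
      const globalRow = wy * tileAOuter;
      const globalCol = wx * tileBOuter;
      const gRow = globalRow + ly + innerRow * workGroupSize[1];
      const gCol = globalCol + lx + innerCol * workGroupSize[0];
      return {gRow, gCol};
    }
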
@@ -195,7 +193,6 @@ export class MatMulPackedProgram implements WebGPUProgram {
 this.outputShape = outputShape;
 this.dispatchLayout = {x: [2], y: [1], z: [0]};
 const dimInner = transposeA ? aShape[1] : aShape[2];
-workPerThread = 2;
 this.workGroupSize =
 computeWorkGroupSizeForMatMul(outputShape[1], dimInner, outputShape[2]);
 if (outputShape[1] === 1 || outputShape[2] === 1) {
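
With the hard-coded workPerThread = 2 removed, the constructor keeps whatever value its caller passes in; together with the flag change in flags_webgpu.ts the effective default stays 2. The constructor signature and the op code are not shown in this diff, so the call-site sketch below is an assumption about how the pieces fit together.

    import {env} from '@tensorflow/tfjs-core';

    // Hypothetical call site: the matmul kernel reads the flag and passes it
    // through instead of having the program overwrite it internally.
    const workPerThread = env().getNumber('WEBGPU_MATMUL_WORK_PER_THREAD');
    const program = new MatMulPackedProgram(
        aShape, outputShape, workPerThread, transposeA, transposeB);
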
@@ -251,11 +248,6 @@ export class MatMulPackedProgram implements WebGPUProgram {
 if (this.outputShape[1] === 1) {
 tileInner *= 4;
 }
-util.assert(
-tileInner % this.workGroupSize[0] === 0 &&
-tileInner % this.workGroupSize[1] === 0,
-() => `tileInner must be multiple of workgroupsize.x ` +
-`and workgroupsize.y`);
 const tileSizeA = [tileAOuter, tileInner];
 const tileSizeB = [tileInner, tileBOuter];
