From a1e5aeabc24271ac4f9eefb434b2e6076367693a Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 12 Aug 2022 16:12:05 +0800 Subject: [PATCH 1/3] webgpu: Merge MatMulPackedProgram and MatMulPackedVec4Program (#6688) * webgpu: Merge MatMulPackedProgram and MatMulPackedVec4Program This PR merges MatMulPackedVec4Program to MatMulPackedProgram and refactors MatMulSplitKProgram. --- tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts | 12 +- .../src/conv_backprop_mm_webgpu.ts | 23 +- tfjs-backend-webgpu/src/flags_webgpu.ts | 5 - .../src/kernels/BatchMatMul_impl.ts | 20 +- .../src/matmul_packed_vec4_webgpu.ts | 263 ------------------ .../src/matmul_packed_webgpu.ts | 248 ++++++++++++++--- .../src/matmul_splitK_webgpu.ts | 165 ++++------- tfjs-backend-webgpu/src/matmul_test.ts | 169 ++++++++++- tfjs-backend-webgpu/src/webgpu_util.ts | 51 ++-- 9 files changed, 458 insertions(+), 498 deletions(-) delete mode 100644 tfjs-backend-webgpu/src/matmul_packed_vec4_webgpu.ts diff --git a/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts index b835f99a1d4..077bece398a 100644 --- a/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts @@ -18,8 +18,7 @@ import {backend_util} from '@tensorflow/tfjs-core'; import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; -import {makeMatMulPackedVec4Source} from './matmul_packed_vec4_webgpu'; -import {makeMatMulPackedSource} from './matmul_packed_webgpu'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; import {WebGPUProgram} from './webgpu_program'; import {computeDispatch, computeWorkGroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util'; @@ -251,14 +250,13 @@ export class Conv2DMMProgram implements WebGPUProgram { getUserCode(): string { const matMulSource = this.isVec4 ? makeMatMulPackedVec4Source( - this.elementsPerThread, this.tileAOuter, this.tileBOuter, - this.tileInner, this.innerElementSize, !this.isChannelsLast) : + this.elementsPerThread, this.workGroupSize, !this.isChannelsLast, + this.tileInner) : makeMatMulPackedSource( this.elementsPerThread, this.workGroupSize, !this.isChannelsLast, this.tileInner); - const elementsSize = this.isVec4 ? - [this.isChannelsLast ? this.innerElementSize : 4, 4, 4] : - [1, 1, 1]; + const elementsSize = + this.isVec4 ? [this.innerElementSize, 4, 4] : [1, 1, 1]; const userCode = ` ${ conv2dCommonSnippet( diff --git a/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts b/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts index e081fa951de..65612bb828a 100644 --- a/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts @@ -17,8 +17,7 @@ import {backend_util, util} from '@tensorflow/tfjs-core'; import {typeSnippet} from './activation_util'; -import {makeMatMulPackedVec4Source} from './matmul_packed_vec4_webgpu'; -import {makeMatMulPackedSource} from './matmul_packed_webgpu'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; import {WebGPUProgram} from './webgpu_program'; import {computeDispatch, computeWorkGroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util'; @@ -123,10 +122,6 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram { 'filterDims : vec2, pads : vec2, stride : vec2, outBackprop : vec4, dimAOuter : i32, dimBOuter : i32, dimInner : i32,'; workGroupSize: [number, number, number]; elementsPerThread: [number, number, number]; - tileAOuter: number; - tileBOuter: number; - tileInner: number; - innerElementSize: number; isVec4?: boolean; constructor(convInfo: backend_util.Conv2DInfo) { @@ -148,24 +143,16 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram { this.elementsPerThread); if (this.isVec4) { - this.innerElementSize = 4; this.variableTypes = ['vec4', 'f32']; - } else { - this.innerElementSize = this.elementsPerThread[0]; } - this.tileAOuter = this.workGroupSize[1] * this.elementsPerThread[1]; - this.tileBOuter = this.workGroupSize[0] * this.elementsPerThread[0]; - this.tileInner = Math.max( - this.workGroupSize[0] * this.innerElementSize, this.workGroupSize[1]); - this.shaderKey = `conv2DDerInputMM_${this.isVec4}_${ - this.elementsPerThread}_${this.innerElementSize}`; + + this.shaderKey = + `conv2DDerInputMM_${this.isVec4}_${this.elementsPerThread}`; } getUserCode(): string { const matMulSource = this.isVec4 ? - makeMatMulPackedVec4Source( - this.elementsPerThread, this.tileAOuter, this.tileBOuter, - this.tileInner, this.innerElementSize) : + makeMatMulPackedVec4Source(this.elementsPerThread, this.workGroupSize) : makeMatMulPackedSource(this.elementsPerThread, this.workGroupSize); const userCode = ` ${conv2dTransposeCommonSnippet(this.isVec4 ? 4 : 1)} diff --git a/tfjs-backend-webgpu/src/flags_webgpu.ts b/tfjs-backend-webgpu/src/flags_webgpu.ts index 488192f0b5e..2df5d8cd5d8 100644 --- a/tfjs-backend-webgpu/src/flags_webgpu.ts +++ b/tfjs-backend-webgpu/src/flags_webgpu.ts @@ -28,11 +28,6 @@ ENV.registerFlag('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', () => 15); */ ENV.registerFlag('WEBGPU_CPU_FORWARD', () => true); -/** - * Thread register block size for matmul kernel. - */ -ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4); - /** * This flag is used to test different types of matmul programs. * diff --git a/tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts b/tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts index e079dae517e..eca42437746 100644 --- a/tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts +++ b/tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts @@ -18,7 +18,6 @@ import {backend_util, broadcast_util, env, TensorInfo, util} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; -import {MatMulPackedVec4Program} from '../matmul_packed_vec4_webgpu'; import {MatMulPackedProgram} from '../matmul_packed_webgpu'; import {MatMulReduceProgram} from '../matmul_reduce_webgpu'; import {MatMulSmallOutputSizeProgram} from '../matmul_small_output_size_webgpu'; @@ -93,9 +92,6 @@ export function batchMatMulImpl({ const batchDim = Math.max(batchDimA, batchDimB); const batchAEqualOne = batchDimA === 1; const batchBEqualOne = batchDimB === 1; - const useVec4 = ((innerShapeA % 4 === 0 && !transposeA) || - (outerShapeA % 4 === 0 && transposeA)) && - outerShapeB % 4 === 0 && !transposeB; const inputs: TensorInfo[] = [a3d, b3d]; const dimensions = [ @@ -133,22 +129,12 @@ export function batchMatMulImpl({ (outerShapeB <= 16 && (outerShapeA <= 512 || innerShapeA >= 2 * outerShapeA))) { matmulProgramType = MatMulProgramType.MatMulSmallOutputSizeProgram; - } else if (useVec4) { - // TODO: Currently we need to make sure that innerShapeA and outerShapeB - // are divisible by 4 since we use vec4 to get data. In future, we can - // remove this limitation by insert 0 to pack data. - matmulProgramType = MatMulProgramType.MatMulPackedVec4Program; } else { matmulProgramType = MatMulProgramType.MatMulPackedProgram; } } switch (matmulProgramType) { - case MatMulProgramType.MatMulPackedVec4Program: - program = new MatMulPackedVec4Program( - a3dShape, outputShape, batchAEqualOne, batchBEqualOne, transposeA, - bias, activation, preluActivationWeights); - break; case MatMulProgramType.MatMulReduceProgram: program = new MatMulReduceProgram( outputShape, batchAEqualOne, batchBEqualOne, transposeA, transposeB, @@ -199,10 +185,8 @@ export function batchMatMulImpl({ break; case MatMulProgramType.MatMulPackedProgram: program = new MatMulPackedProgram( - a3dShape, outputShape, - env().get('WEBGPU_MATMUL_WORK_PER_THREAD') as number, batchAEqualOne, - batchBEqualOne, transposeA, transposeB, bias, activation, - preluActivationWeights); + a3dShape, outputShape, batchAEqualOne, batchBEqualOne, transposeA, + transposeB, bias, activation, preluActivationWeights); break; default: throw new Error(`Unsupported MatMulProgramType ${matmulProgramType}.`); diff --git a/tfjs-backend-webgpu/src/matmul_packed_vec4_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_vec4_webgpu.ts deleted file mode 100644 index 470e7b344f8..00000000000 --- a/tfjs-backend-webgpu/src/matmul_packed_vec4_webgpu.ts +++ /dev/null @@ -1,263 +0,0 @@ -/** - * @license - * Copyright 2020 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core'; - -import {activationFnSnippet} from './activation_util'; -import {matMulReadWriteFnSource} from './matmul_packed_webgpu'; -import {WebGPUProgram} from './webgpu_program'; -import {computeDispatch} from './webgpu_util'; - -const writeDataToSubASnippet = - (transpose: boolean, innerAElementSize: number) => { - if (transpose) { - return ` - mm_Asub[inputRow][inputCol] = mm_readA(batch, - t * TileInner + inputRow, - globalRowStart / ${innerAElementSize} + inputCol); - `; - - } else { - return ` - mm_Asub[inputRow][inputCol] = mm_readA(batch, - globalRow + innerRow, - t * TileInner / ${innerAElementSize} + inputCol); - `; - } - }; - -const calculateResultSnippet = - (transposeA: boolean, innerElementSize: number) => { - if (transposeA) { - return ` - let ACached0 = mm_Asub[k * InnerElementSize][localRow]; - let ACached1 = mm_Asub[k * InnerElementSize + 1][localRow]; - let ACached2 = mm_Asub[k * InnerElementSize + 2][localRow]; - ${ - innerElementSize === 3 ? - '' : - 'let ACached3 = mm_Asub[k * InnerElementSize + 3][localRow];'} - for (var i = 0; i < RowPerThread; i = i + 1) { - acc[i] = BCached[0] * ACached0[i] + acc[i]; - acc[i] = BCached[1] * ACached1[i] + acc[i]; - acc[i] = BCached[2] * ACached2[i] + acc[i]; - ${ - innerElementSize === 3 ? - '' : - 'acc[i] = BCached[3] * ACached3[i] + acc[i];'} - }`; - } else { - return ` - for (var i = 0; i < RowPerThread; i = i + 1) { - let ACached = mm_Asub[tileRow + i][k]; - acc[i] = BCached[0] * ACached.x + acc[i]; - acc[i] = BCached[1] * ACached.y + acc[i]; - acc[i] = BCached[2] * ACached.z + acc[i]; - ${ - innerElementSize === 3 ? - '' : - 'acc[i] = BCached[3] * ACached.w + acc[i];'} - }`; - } - }; - -export function makeMatMulPackedVec4Source( - workPerThread: number[], tileAOuter: number, tileBOuter: number, - tileInner: number, innerElementSize = 4, transposeA = false): string { - const tileAWidth = transposeA ? tileAOuter : tileInner; - const tileAHight = transposeA ? tileInner : tileAOuter; - const innerAElementSize = transposeA ? workPerThread[1] : innerElementSize; - // For simplicity, if transposeA is true, tileAOuter must be equal to - // tileBOuter. - util.assert( - ((transposeA && tileAOuter === tileBOuter) || - (tileInner % 4 === 0 || tileInner % 3 === 0)) && - workPerThread[0] === 4 && - (innerElementSize === 3 || innerElementSize === 4), - () => `tileInner ${tileInner} must be divisible by 4|3. ColPerThread ${ - workPerThread[0]} must be 4. - innerElementSize ${innerElementSize} must be 3|4.`); - return ` - var mm_Asub : array, ${ - tileAWidth / innerAElementSize}>, ${tileAHight}>; - var mm_Bsub : array, ${ - tileBOuter / workPerThread[0]}>, ${tileInner}>; - - const RowPerThread = ${workPerThread[1]}; - const ColPerThread = ${workPerThread[0]}; - const InnerElementSize = ${innerElementSize}; - const TileInner = ${tileInner}; - - @compute @workgroup_size(workGroupSizeX, workGroupSizeY, workGroupSizeZ) - fn main(@builtin(local_invocation_id) LocalId : vec3, - @builtin(global_invocation_id) GlobalId : vec3, - @builtin(num_workgroups) NumWorkgroups: vec3, - @builtin(workgroup_id) workgroupId: vec3) { - localId = LocalId; - globalId = GlobalId; - numWorkgroups = NumWorkgroups; - - let localRow = i32(localId.y); - let tileRow = ${tileAOuter === 1 ? '0' : 'localRow * RowPerThread'}; - let tileCol = i32(localId.x); - - let globalRow = ${ - tileAOuter === 1 ? '0' : 'i32(globalId.y) * RowPerThread'}; - let globalCol = i32(globalId.x); - let batch = i32(globalId.z); - let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - - let numTiles = (uniforms.dimInner - 1) / TileInner + 1; - - var acc: array, RowPerThread>; - var BCached : array, 4>; - - // Loop over shared dimension. - let RowPerThreadB = TileInner / i32(workGroupSizeY); - let tileRowB = localRow * RowPerThreadB; - for (var t = 0; t < numTiles; t = t + 1) { - // Load one tile of A into local memory. - for (var innerRow = 0; innerRow < RowPerThread; innerRow = innerRow + 1) { - let inputRow = tileRow + innerRow; - let inputCol = tileCol; - ${writeDataToSubASnippet(transposeA, innerAElementSize)} - } - - // Load one tile of B into local memory. - for (var innerRow = 0; innerRow < RowPerThreadB; innerRow = innerRow + 1) { - let inputRow = tileRowB + innerRow; - let inputCol = tileCol; - mm_Bsub[inputRow][inputCol] = mm_readB(batch, t * TileInner + inputRow, globalCol); - } - - workgroupBarrier(); - - // Compute acc values for a single thread. - for (var k = 0; k < TileInner / InnerElementSize; k = k + 1) { - BCached[0] = mm_Bsub[k * InnerElementSize][tileCol]; - BCached[1] = mm_Bsub[k * InnerElementSize + 1][tileCol]; - BCached[2] = mm_Bsub[k * InnerElementSize + 2][tileCol]; - ${ - innerElementSize === 3 ? - '' : - 'BCached[3] = mm_Bsub[k * InnerElementSize + 3][tileCol];'} - - ${calculateResultSnippet(transposeA, innerElementSize)} - } - - workgroupBarrier(); - } - - for (var innerRow = 0; innerRow < RowPerThread; innerRow = innerRow + 1) { - mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); - } - }`; -} - -export class MatMulPackedVec4Program implements WebGPUProgram { - outputShape: number[]; - shaderKey: string; - dispatchLayout: {x: number[], y: number[], z: number[]}; - dispatch: [number, number, number]; - variableNames = ['A', 'B']; - uniforms = `dimAOuter : i32, dimBOuter : i32, dimInner : i32,`; - workGroupSize: [number, number, number] = [8, 8, 1]; - elementsPerThread: [number, number, number]; - isVec4 = true; - aShape: [number, number, number]; - addBias: boolean; - activation: backend_util.Activation; - hasPreluActivationWeights: boolean; - tileAOuter: number; - tileBOuter: number; - tileInner: number; - fitAOuter: boolean; - fitBOuter: boolean; - fitInner: boolean; - batchAEqualOne: boolean; - batchBEqualOne: boolean; - transposeA: boolean; - - constructor( - aShape: [number, number, number], outputShape: [number, number, number], - batchAEqualOne: boolean, batchBEqualOne: boolean, transposeA = false, - bias: TensorInfo = null, activation: backend_util.Activation = null, - preluActivationWeights: TensorInfo = null) { - this.outputShape = outputShape; - this.dispatchLayout = {x: [2], y: [1], z: [0]}; - // The first element in elementsPerThread must be 4. - if (outputShape[1] === 1 && !transposeA) { - this.elementsPerThread = [4, 1, 1]; - } else { - this.elementsPerThread = [4, 4, 1]; - } - this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, this.workGroupSize, - this.elementsPerThread); - - const addBias = bias != null; - const hasPreluActivationWeights = preluActivationWeights != null; - if (addBias) { - this.variableNames.push('bias'); - } - - if (hasPreluActivationWeights) { - this.variableNames.push('preluActivationWeights'); - } - - this.tileAOuter = outputShape[1] === 1 && !transposeA ? - 1 : - this.workGroupSize[1] * this.elementsPerThread[1]; - this.tileBOuter = this.workGroupSize[0] * this.elementsPerThread[0]; - this.tileInner = this.tileBOuter; - - this.aShape = aShape; - this.addBias = addBias; - this.activation = activation; - this.hasPreluActivationWeights = hasPreluActivationWeights; - this.batchAEqualOne = batchAEqualOne; - this.batchBEqualOne = batchBEqualOne; - this.transposeA = transposeA; - - const dimInner = transposeA ? aShape[1] : aShape[2]; - this.fitAOuter = outputShape[1] % this.tileAOuter === 0; - this.fitBOuter = outputShape[2] % this.tileBOuter === 0; - this.fitInner = dimInner % this.tileInner === 0; - - this.shaderKey = `matMulPackedVec4_${this.activation}_${this.fitAOuter}_${ - this.fitBOuter}_${this.fitInner}_${this.elementsPerThread}_${ - this.batchAEqualOne}_${this.batchBEqualOne}_${this.transposeA}`; - } - - getUserCode(): string { - const userCode = ` - ${ - activationFnSnippet( - this.activation, this.hasPreluActivationWeights, true)} - ${ - matMulReadWriteFnSource( - this.addBias, this.activation, this.batchAEqualOne, - this.batchBEqualOne, false, false, this.fitAOuter, this.fitBOuter, - this.fitInner, 4)} - ${ - makeMatMulPackedVec4Source( - this.elementsPerThread, this.tileAOuter, this.tileBOuter, - this.tileInner, 4, this.transposeA)} - `; - return userCode; - } -} diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index e206b70ef4e..3e5fcfcb6d1 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -18,7 +18,7 @@ import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core'; import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {getMainHeaderString, WebGPUProgram} from './webgpu_program'; -import {computeDispatch, computeWorkGroupSizeForMatMul} from './webgpu_util'; +import {computeDispatch, computeWorkGroupInfoForMatMul} from './webgpu_util'; export function matMulReadFnSource( batchAEqualOne: boolean, batchBEqualOne: boolean, transposeA: boolean, @@ -107,11 +107,160 @@ export function matMulReadWriteFnSource( `; } +const writeDataToSubAVec4Snippet = (transpose: boolean) => { + if (transpose) { + return ` + mm_Asub[inputRow][inputCol] = mm_readA(batch, + kStart + inputRow, + globalRowStart / InnerElementSize + inputCol); + `; + + } else { + return ` + mm_Asub[inputRow][inputCol] = mm_readA(batch, + globalRow + innerRow, + kStart / InnerElementSize + inputCol); + `; + } +}; + +const calculateResultSnippet = + (transposeA: boolean, innerElementSize: number) => { + if (transposeA) { + return ` + let ACached0 = mm_Asub[k * InnerElementSize][localRow]; + let ACached1 = mm_Asub[k * InnerElementSize + 1][localRow]; + let ACached2 = mm_Asub[k * InnerElementSize + 2][localRow]; + ${ + innerElementSize === 3 ? + '' : + 'let ACached3 = mm_Asub[k * InnerElementSize + 3][localRow];'} + for (var i = 0; i < RowPerThread; i = i + 1) { + acc[i] = BCached0 * ACached0[i] + acc[i]; + acc[i] = BCached1 * ACached1[i] + acc[i]; + acc[i] = BCached2 * ACached2[i] + acc[i]; + ${ + innerElementSize === 3 ? + '' : + 'acc[i] = BCached3 * ACached3[i] + acc[i];'} + }`; + } else { + return ` + for (var i = 0; i < RowPerThread; i = i + 1) { + let ACached = mm_Asub[tileRow + i][k]; + acc[i] = BCached0 * ACached.x + acc[i]; + acc[i] = BCached1 * ACached.y + acc[i]; + acc[i] = BCached2 * ACached.z + acc[i]; + ${ + innerElementSize === 3 ? '' : + 'acc[i] = BCached3 * ACached.w + acc[i];'} + }`; + } + }; + +export function makeMatMulPackedVec4Source( + workPerThread: number[], workGroupSize: [number, number, number], + transposeA = false, tileInner = 32, splitK = false, + isVectorA = false): string { + const tileAOuter = workGroupSize[1] * workPerThread[1]; + const tileBOuter = workGroupSize[0] * workPerThread[0]; + const tileAWidth = transposeA ? tileAOuter : tileInner; + const tileAHight = transposeA ? tileInner : tileAOuter; + const innerElementSize = tileAWidth / workGroupSize[0]; + const rowPerThreadB = tileInner / workGroupSize[1]; + util.assert( + ((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || + (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && + tileAWidth % workGroupSize[0] === 0 && + tileInner % workGroupSize[1] === 0 && workPerThread[0] === 4, + () => `If transposeA ${transposeA} is true, innerElementSize ${ + innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4. + Otherwise, innerElementSize ${innerElementSize} must be 3 or 4. + tileAWidth ${tileAWidth} must be divisible by workGroupSize[0]${ + workGroupSize[0]}. tileInner ${ + tileInner} must be divisible by workGroupSize[1] ${ + workGroupSize[1]}. ColPerThread ${workPerThread[0]} must be 4.`); + return ` + var mm_Asub : array, ${ + tileAWidth / innerElementSize}>, ${tileAHight}>; + var mm_Bsub : array, ${ + tileBOuter / workPerThread[0]}>, ${tileInner}>; + + const RowPerThread = ${workPerThread[1]}; + const ColPerThread = ${workPerThread[0]}; + const InnerElementSize = ${innerElementSize}; + const TileInner = ${tileInner}; + + @compute @workgroup_size(workGroupSizeX, workGroupSizeY, workGroupSizeZ) + fn main(@builtin(local_invocation_id) LocalId : vec3, + @builtin(global_invocation_id) GlobalId : vec3, + @builtin(num_workgroups) NumWorkgroups: vec3, + @builtin(workgroup_id) workgroupId: vec3) { + localId = LocalId; + globalId = GlobalId; + numWorkgroups = NumWorkgroups; + + let localRow = i32(localId.y); + let tileRow = ${isVectorA ? '0' : 'localRow * RowPerThread'}; + let tileCol = i32(localId.x); + + let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * RowPerThread'}; + let globalCol = i32(globalId.x); + let batch = ${splitK ? '0' : 'i32(globalId.z)'}; + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; + + let numTiles = ${splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'}; + var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'}; + + var acc: array, RowPerThread>; + + // Loop over shared dimension. + let tileRowB = localRow * ${rowPerThreadB}; + for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var innerRow = 0; innerRow < RowPerThread; innerRow = innerRow + 1) { + let inputRow = tileRow + innerRow; + let inputCol = tileCol; + ${writeDataToSubAVec4Snippet(transposeA)} + } + + // Load one tile of B into local memory. + for (var innerRow = 0; innerRow < ${ + rowPerThreadB}; innerRow = innerRow + 1) { + let inputRow = tileRowB + innerRow; + let inputCol = tileCol; + mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol); + } + kStart = kStart + TileInner; + workgroupBarrier(); + + // Compute acc values for a single thread. + for (var k = 0; k < TileInner / InnerElementSize; k = k + 1) { + let BCached0 = mm_Bsub[k * InnerElementSize][tileCol]; + let BCached1 = mm_Bsub[k * InnerElementSize + 1][tileCol]; + let BCached2 = mm_Bsub[k * InnerElementSize + 2][tileCol]; + ${ + innerElementSize === 3 ? + '' : + 'let BCached3 = mm_Bsub[k * InnerElementSize + 3][tileCol];'} + + ${calculateResultSnippet(transposeA, innerElementSize)} + } + + workgroupBarrier(); + } + + for (var innerRow = 0; innerRow < RowPerThread; innerRow = innerRow + 1) { + mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); + } + }`; +} + const writeDataToSubASnippet = (transpose: boolean) => { if (transpose) { return ` mm_Asub[inputRow][inputCol] = mm_readA(batch, - t * TileInner + inputRow, + kStart + inputRow, globalRowStart + inputCol); `; @@ -119,7 +268,7 @@ const writeDataToSubASnippet = (transpose: boolean) => { return ` mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, - t * TileInner + inputCol); + kStart + inputCol); `; } }; @@ -132,7 +281,7 @@ const readDataFromSubASnippet = (transposeA: boolean) => { export function makeMatMulPackedSource( workPerThread: number[], workGroupSize: [number, number, number], - transposeA = false, tileInner = 32): string { + transposeA = false, tileInner = 32, splitK = false): string { const tileAOuter = workPerThread[1] * workGroupSize[1]; const tileBOuter = workPerThread[0] * workGroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -170,10 +319,12 @@ export function makeMatMulPackedSource( let globalRow = i32(globalId.y) * RowPerThread; let globalCol = i32(globalId.x) * ColPerThread; - let batch = i32(globalId.z); + let batch = ${splitK ? '0' : 'i32(globalId.z)'}; let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = (uniforms.dimInner - 1) / TileInner + 1; + let numTiles = ${ + splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'}; + var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'}; var acc : array, RowPerThread>; @@ -207,11 +358,11 @@ export function makeMatMulPackedSource( let inputRow = tileRowB + innerRow; let inputCol = tileCol + innerCol; mm_Bsub[inputRow][inputCol] = mm_readB(batch, - t * TileInner + inputRow, + kStart + inputRow, globalCol + innerCol); } } - + kStart = kStart + TileInner; workgroupBarrier(); // Compute acc values for a single thread. @@ -308,10 +459,10 @@ export class MatMulPackedProgram implements WebGPUProgram { shaderKey: string; dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; - workPerThread: number; variableNames = ['A', 'B']; uniforms = `dimAOuter : i32, dimBOuter : i32, dimInner : i32,`; - workGroupSize: [number, number, number] = [16, 16, 1]; + workGroupSize: [number, number, number]; + elementsPerThread: [number, number, number]; transposeA: boolean; transposeB: boolean; addBias: boolean; @@ -323,35 +474,38 @@ export class MatMulPackedProgram implements WebGPUProgram { fitBOuter: boolean; fitInner: boolean; tileInner: number; + isVectorA: boolean; + isVec4: boolean; constructor( aShape: [number, number, number], outputShape: [number, number, number], - workPerThread: number, batchAEqualOne: boolean, batchBEqualOne: boolean, - transposeA = false, transposeB = false, bias: TensorInfo = null, + batchAEqualOne: boolean, batchBEqualOne: boolean, transposeA = false, + transposeB = false, bias: TensorInfo = null, activation: backend_util.Activation = null, preluActivationWeights: TensorInfo = null) { this.outputShape = outputShape; this.dispatchLayout = {x: [2], y: [1], z: [0]}; const dimInner = transposeA ? aShape[1] : aShape[2]; - this.workGroupSize = - computeWorkGroupSizeForMatMul(outputShape[1], dimInner, outputShape[2]); - if (outputShape[1] === 1 || outputShape[2] === 1) { - workPerThread = 1; + this.isVec4 = ((dimInner % 4 === 0 && !transposeA) || + (outputShape[1] % 4 === 0 && transposeA)) && + outputShape[2] % 4 === 0 && !transposeB; + this.isVectorA = outputShape[1] === 1 && !transposeA; + + if (!this.isVec4 && this.isVectorA) { + // For makeVectorMatrixProductSource + this.elementsPerThread = [1, 1, 1]; + this.workGroupSize = [32, 1, 1]; + } else { + const workGroupInfo = computeWorkGroupInfoForMatMul( + outputShape[1], dimInner, outputShape[2], transposeA); + this.workGroupSize = workGroupInfo.workGroupSize; + this.elementsPerThread = workGroupInfo.elementsPerThread; } + this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize, - [workPerThread, workPerThread, 1]); - // If dispaching number is one, it means only one work group is running. - // For modern GPUs, it supports multiple work groups running in parallel. - // So there may be some idle hardware threads. - // In this case, we prefer to reduce the work per thread and improve the - // thread utilization - if (util.arraysEqual(this.dispatch, [1, 1, 1])) { - workPerThread = 1; - this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, this.workGroupSize, - [workPerThread, workPerThread, 1]); - } + this.elementsPerThread); + const addBias = bias != null; const hasPreluActivationWeights = preluActivationWeights != null; if (addBias) { @@ -362,7 +516,6 @@ export class MatMulPackedProgram implements WebGPUProgram { this.variableNames.push('preluActivationWeights'); } - this.workPerThread = workPerThread; this.transposeA = transposeA; this.transposeB = transposeB; this.addBias = addBias; @@ -372,20 +525,22 @@ export class MatMulPackedProgram implements WebGPUProgram { this.batchBEqualOne = batchBEqualOne; [this.fitAOuter, this.fitBOuter, this.fitInner] = this.getShapeFit(outputShape[1], outputShape[2], dimInner); - this.shaderKey = `matMulPacked_${this.workPerThread}_${transposeA}_${ + this.shaderKey = `matMulPacked_${this.elementsPerThread}_${transposeA}_${ transposeB}_${this.activation}_${this.fitAOuter}_${this.fitBOuter}_${ - this.fitInner}_${this.outputShape[1] > 1}_${this.batchAEqualOne}_${ - this.batchBEqualOne}`; + this.fitInner}_${this.isVec4}_${this.isVectorA}_${ + this.batchAEqualOne}_${this.batchBEqualOne}`; } getShapeFit(dimAOuter: number, dimBOuter: number, dimInner: number): boolean[] { - const tileAOuter = this.workGroupSize[1] * this.workPerThread; - const tileBOuter = this.workGroupSize[0] * this.workPerThread; - this.tileInner = 32; + const tileAOuter = this.workGroupSize[1] * this.elementsPerThread[1]; + const tileBOuter = this.workGroupSize[0] * this.elementsPerThread[0]; - if (this.outputShape[1] === 1) { + if (!this.isVec4 && this.isVectorA) { + // For makeVectorMatrixProductSource this.tileInner = this.workGroupSize[0] * 4; + } else { + this.tileInner = tileBOuter; } const fitAOuter = dimAOuter % tileAOuter === 0; @@ -396,19 +551,26 @@ export class MatMulPackedProgram implements WebGPUProgram { getUserCode(): string { const userCode = ` - ${activationFnSnippet(this.activation, this.hasPreluActivationWeights)} + ${ + activationFnSnippet( + this.activation, this.hasPreluActivationWeights, this.isVec4)} ${ matMulReadWriteFnSource( this.addBias, this.activation, this.batchAEqualOne, this.batchBEqualOne, false /* transposeA is implemented in makeMatMulPackedSource */, - this.transposeB, this.fitAOuter, this.fitBOuter, this.fitInner)} + this.transposeB, this.fitAOuter, this.fitBOuter, this.fitInner, + this.isVec4 ? 4 : 1)} ${ - this.outputShape[1] > 1 ? - makeMatMulPackedSource( - [this.workPerThread, this.workPerThread, 1], this.workGroupSize, - this.transposeA, this.tileInner) : - makeVectorMatrixProductSource(this.workGroupSize, this.transposeA)} + this.isVec4 ? + makeMatMulPackedVec4Source( + this.elementsPerThread, this.workGroupSize, this.transposeA, + this.tileInner, false, this.isVectorA) : + (this.isVectorA ? makeVectorMatrixProductSource( + this.workGroupSize, this.transposeA) : + makeMatMulPackedSource( + this.elementsPerThread, this.workGroupSize, + this.transposeA, this.tileInner))} `; return userCode; } diff --git a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts index 6b484d838e2..ac6ea700f64 100644 --- a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts @@ -17,9 +17,9 @@ import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core'; -import {activationFnSnippet, biasActivationSnippet} from './activation_util'; -import {matMulReadFnSource} from './matmul_packed_webgpu'; -import {getMainHeaderAndGlobalIndexString, getMainHeaderString, WebGPUProgram} from './webgpu_program'; +import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source, matMulReadFnSource} from './matmul_packed_webgpu'; +import {getMainHeaderAndGlobalIndexString, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class MatMulSplitKProgram implements WebGPUProgram { @@ -36,6 +36,7 @@ export class MatMulSplitKProgram implements WebGPUProgram { atomic = true; batchAEqualOne: boolean; batchBEqualOne: boolean; + isVec4 = false; tileInner = 32; constructor( @@ -47,13 +48,19 @@ export class MatMulSplitKProgram implements WebGPUProgram { () => 'MatMulSplitKProgram only supports batch = 1.'); this.outputShape = outputShape; this.dispatchLayout = {x: [2], y: [1], z: [0, 3]}; + this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 || + !transposeA && dimInner % 4 === 0) && + this.outputShape[2] % 4 === 0; this.elementsPerThread = [4, 4, this.tileInner]; - if (this.outputShape[1] < 16) { - this.elementsPerThread[1] = 1; - } - if (this.outputShape[2] < 16) { - this.elementsPerThread[0] = 1; + if (!this.isVec4) { + if (this.outputShape[1] < 16) { + this.elementsPerThread[1] = 1; + } + if (this.outputShape[2] < 16) { + this.elementsPerThread[0] = 1; + } } + this.dispatch = computeDispatch( this.dispatchLayout, [ @@ -66,133 +73,59 @@ export class MatMulSplitKProgram implements WebGPUProgram { this.transposeB = transposeB; this.batchAEqualOne = batchAEqualOne; this.batchBEqualOne = batchBEqualOne; - this.shaderKey = `matMulSplitK_${transposeA}_${transposeB}_${ - batchAEqualOne}_${batchBEqualOne}_${this.elementsPerThread}`; + this.shaderKey = + `matMulSplitK_${transposeA}_${transposeB}_${batchAEqualOne}_${ + batchBEqualOne}_${this.elementsPerThread}_${this.isVec4}`; } getUserCode(): string { // atomicAdd only supports uint/int type. For float, we use // atomicCompareExchangeWeak to simulate. - const atomicAddSnippet = ` - var oldValue = atomicLoad(&(result[flatIndex])); - var exchanged = false; - for (; !exchanged;) { - let newValueF32 = bitcast(oldValue) + value; - let newValue = bitcast(newValueF32); - let res = atomicCompareExchangeWeak(&(result[flatIndex]), oldValue, newValue); - oldValue = res.old_value; - exchanged = res.exchanged; - } - `; + const atomicAddSnippet = (component: number) => { + return ` + for (var i = 0; i < ${component}; i = i + 1) + { + var oldValue = atomicLoad(&(result[flatIndex + i])); + var exchanged = false; + for (; !exchanged;) { + let newValueF32 = bitcast(oldValue) + ${ + component > 1 ? 'value[i]' : 'value'}; + let newValue = bitcast(newValueF32); + let res = atomicCompareExchangeWeak(&(result[flatIndex + i]), oldValue, newValue); + oldValue = res.old_value; + exchanged = res.exchanged; + } + } + `; + }; + const component = this.isVec4 ? 4 : 1; const userCode = ` ${ matMulReadFnSource( - this.batchAEqualOne, this.batchBEqualOne, this.transposeA, - this.transposeB)} - fn mm_write(batch: i32, row : i32, col : i32, valueIn : f32) { + this.batchAEqualOne, this.batchBEqualOne, false, this.transposeB, + false, false, false, component)} + fn mm_write(batch: i32, row : i32, colIn : i32, value : ${ + typeSnippet(component)}) { + let col = colIn * ${component}; if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { let coords = vec3(batch, row, col); let flatIndex = getOutputIndexFromCoords(coords); - var value = valueIn; // The problem is that we should initialize output to zero before using. // Otherwise, the original value will be added to the result. - ${atomicAddSnippet} + ${atomicAddSnippet(component)} } } - - ${this.makeMatMulSplitKSource()} + ${ + this.isVec4 ? makeMatMulPackedVec4Source( + this.elementsPerThread, this.workGroupSize, + this.transposeA, this.tileInner, true) : + makeMatMulPackedSource( + this.elementsPerThread, this.workGroupSize, + this.transposeA, this.tileInner, true)} `; return userCode; } - - makeMatMulSplitKSource(): string { - const tileAOuter = this.workGroupSize[1] * this.elementsPerThread[1]; - const tileBOuter = this.workGroupSize[0] * this.elementsPerThread[0]; - const rowPerThread = this.elementsPerThread[1]; - const colPerThread = this.elementsPerThread[0]; - const colPerThreadA = this.tileInner / this.workGroupSize[0]; - const rowPerThreadB = this.tileInner / this.workGroupSize[1]; - util.assert( - this.tileInner % this.workGroupSize[0] === 0 && - this.tileInner % this.workGroupSize[1] === 0, - () => - `tileInner ${this.tileInner} must be divisible by workGroupSize[0]${ - this.workGroupSize[0]} and workGroupSize[1]${ - this.workGroupSize[1]}`); - return ` - var mm_Asub : array, ${ - tileAOuter}>; - var mm_Bsub : array, ${ - this.tileInner}>; - ${getMainHeaderString()} - let tileRow = i32(localId.y) * ${rowPerThread}; - let tileCol = i32(localId.x) * ${colPerThread}; - - let globalRow = i32(globalId.y) * ${rowPerThread}; - let globalCol = i32(globalId.x) * ${colPerThread}; - let batch = 0; - let kStart = i32(globalId.z) * ${this.tileInner}; - - // Load one tile of A into local memory. - let tileColA = i32(localId.x) * ${colPerThreadA}; - for (var innerRow = 0; innerRow < ${ - rowPerThread}; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < ${ - colPerThreadA}; innerCol = innerCol + 1) { - let inputRow = tileRow + innerRow; - let inputCol = tileColA + innerCol; - mm_Asub[inputRow][inputCol] = mm_readA(${ - this.batchAEqualOne ? 0 : 'batch'}, - globalRow + innerRow, - kStart + inputCol); - } - } - // Load one tile of B into local memory. - let tileRowB = i32(localId.y) * ${rowPerThreadB}; - for (var innerRow = 0; innerRow < ${ - rowPerThreadB}; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < ${ - colPerThread}; innerCol = innerCol + 1) { - let inputRow = tileRowB + innerRow; - let inputCol = tileCol + innerCol; - mm_Bsub[inputRow][inputCol] = mm_readB(${ - this.batchBEqualOne ? 0 : 'batch'}, - kStart + inputRow, - globalCol + innerCol); - } - } - - workgroupBarrier(); - - var acc : array, ${rowPerThread}>; - // Loop over shared dimension. Compute acc values for a single thread. - for (var k = 0; k < ${this.tileInner}; k = k + 1) { - var BCached : array; - for (var inner = 0; inner < ${colPerThread}; inner = inner + 1) { - BCached[inner] = mm_Bsub[k][tileCol + inner]; - } - - for (var innerRow = 0; innerRow < ${ - rowPerThread}; innerRow = innerRow + 1) { - let ACached = mm_Asub[tileRow + innerRow][k]; - for (var innerCol = 0; innerCol < ${ - colPerThread}; innerCol = innerCol + 1) { - acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; - } - } - } - - for (var innerRow = 0; innerRow < ${ - rowPerThread}; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < ${ - colPerThread}; innerCol = innerCol + 1) { - mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]); - } - } - } - `; - } } export class BiasActivationProgram implements WebGPUProgram { diff --git a/tfjs-backend-webgpu/src/matmul_test.ts b/tfjs-backend-webgpu/src/matmul_test.ts index bb0dd4bd3ce..e76f1c19575 100644 --- a/tfjs-backend-webgpu/src/matmul_test.ts +++ b/tfjs-backend-webgpu/src/matmul_test.ts @@ -895,7 +895,8 @@ function matmulTest(programType: MatMulProgramType) { 1164, 1281, 1375, 1472, 1217, 1327 ]); }); - it('matMul MatMulPackedVec4Program outputShape[1] > 1', async () => { + + it('A x B vec4', async () => { const a = tf.tensor3d( [ 2, 1, 3, 2, 1, 1, 1, 5, 6, 7, 8, 1, 2, 2, 1, 9, 11, 10, 1, @@ -989,7 +990,7 @@ function matmulTest(programType: MatMulProgramType) { ]); }); - it('matMul MatMulPackedVec4Program outputShape[1] == 1', async () => { + it('A x B vec4 A is a vector', async () => { const a = tf.tensor3d([2, 1, 3, 2], [1, 1, 4]); const b = tf.tensor3d( @@ -1011,8 +1012,7 @@ function matmulTest(programType: MatMulProgramType) { ]); }); - // Below cases are from mat_mul_test.ts in tfjs-core. - it('A^t x B', async () => { + it('A^t x B vec4', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); const b = tf.tensor2d([1, 0, 2, 4, 3, 0, 5, 6], [2, 4]); @@ -1025,6 +1025,163 @@ function matmulTest(programType: MatMulProgramType) { test_util.expectArraysClose(result, expected); }); + it('fused A x B vec4 with relu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu'}); + + expect(c.shape).toEqual([2, 4]); + expectArraysClose(await c.data(), [30, 0, 36, 0, 66, 0, 68, 0]); + }); + + it('fused A x B vec4 with elu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'elu'}); + + expect(c.shape).toEqual([2, 4]); + expectArraysClose( + await c.data(), [30, -0.9999, 36, -1, 66, -0.9999, 68, -1]); + }); + + it('fused A x B vec4 with relu6', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'relu6'}); + + expect(c.shape).toEqual([2, 4]); + expectArraysClose(await c.data(), [6, 0, 6, 0, 6, 0, 6, 0]); + }); + + it('fused A x B vec4 with prelu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const alpha = tf.tensor2d([0.5, 0.5, 0.5, 0.5], [1, 4]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul({ + a, + b, + transposeA, + transposeB, + bias: null, + activation: 'prelu', + preluActivationWeights: alpha + }); + + expect(c.shape).toEqual([2, 4]); + expectArraysClose(await c.data(), [30, -4.5, 36, -15, 66, -4.5, 68, -27]); + }); + + it('fused A x B vec4 with leakyrelu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const alpha = 0.3; + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul({ + a, + b, + transposeA, + transposeB, + bias: null, + activation: 'leakyrelu', + leakyreluAlpha: alpha + }); + + expect(c.shape).toEqual([2, 4]); + expectArraysClose(await c.data(), [ + 30, -2.700000047683716, 36, -9, 66, -2.700000047683716, 68, + -16.200000762939453 + ]); + }); + + it('fused A x B vec4 with leakyrelu not provided.', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'leakyrelu'}); + + expect(c.shape).toEqual([2, 4]); + // leakyRelu should use default alpha=0.2. + expectArraysClose(await c.data(), [ + 30, -1.8000000715255737, 36, -6, 66, -1.8000000715255737, 68, + -10.800000190734863 + ]); + }); + + it('fused A x B with sigmoid', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const transposeA = false; + const transposeB = false; + + const c = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: null, activation: 'sigmoid'}); + + expect(c.shape).toEqual([2, 4]); + expectArraysClose(await c.data(), [ + 1, 0.00012339462409727275, 1, 9.35763443186792e-14, 1, + 0.00012339462409727275, 1, 3.5326268130932535e-24 + ]); + }); + + it('fused A x B vec4 with 2d bias and relu', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const c = tf.tensor2d([1, 1, 1, 1, 1, 1, 1, 1], [2, 4]); + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: 'relu'}); + + expect(d.shape).toEqual([2, 4]); + expectArraysClose(await d.data(), [31, 0, 37, 0, 67, 0, 69, 0]); + }); + + it('fused A x B vec4 with relu and broadcasted bias', async () => { + const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const b = tf.tensor2d( + [0, 1, -3, 2, 2, 1, 1, 0, 2, 4, 3, 0, 5, -6, 7, -8], [4, 4]); + const c = tf.tensor1d([1, 1, 1, 1]); + const act: tf.fused.Activation = 'relu'; + const transposeA = false; + const transposeB = false; + + const d = tf.fused.matMul( + {a, b, transposeA, transposeB, bias: c, activation: act}); + + expect(d.shape).toEqual([2, 4]); + expectArraysClose(await d.data(), [31, 0, 37, 0, 67, 0, 69, 0]); + }); + + // Below cases are from mat_mul_test.ts in tfjs-core. it('A x B', async () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]); @@ -1794,10 +1951,6 @@ function matmulBatchTest(programType: MatMulProgramType) { } for (let i = 0; i < MatMulProgramType.MatMulMax; i++) { - // TODO: Add tests for MatMulPackedVec4Program. - if (i === MatMulProgramType.MatMulPackedVec4Program) { - continue; - } describeWithFlags(`matmul ${MatMulProgramType[i]}`, ALL_ENVS, matmulTest(i)); // Skip MatMulSplitKProgram since it doesn't support batch > 1; if (i !== MatMulProgramType.MatMulSplitKProgram) { diff --git a/tfjs-backend-webgpu/src/webgpu_util.ts b/tfjs-backend-webgpu/src/webgpu_util.ts index eb4ce40e907..b067252b527 100644 --- a/tfjs-backend-webgpu/src/webgpu_util.ts +++ b/tfjs-backend-webgpu/src/webgpu_util.ts @@ -59,6 +59,37 @@ export function computeDispatch( return [dispatchX, dispatchY, dispatchZ]; } +export type WorkGroupInfo = { + workGroupSize: [number, number, number], + elementsPerThread: [number, number, number], +}; + +export function computeWorkGroupInfoForMatMul( + dimAOuter: number, dimInner: number, dimBOuter: number, + transposeA = false): WorkGroupInfo { + // These are experimental values. Usually, we need to adjust the work group + // size based on the input shapes to improve the EU occupancy. + // TODO: WebGPU limits the maximum allowed shared memory size as 16K. To make + // sure it doesn't exceed this limitations. Temporarily reduce the work group + // size to [8, 8, 1] and the work per thread size is [4, 4, 1]. But we should + // revisit it and find the balance between work group size and work per thread + // size. + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread: [number, number, number] = [4, 4, 1]; + + if (!transposeA) { + if (dimAOuter <= 8) { + elementsPerThread[1] = 1; + } + + if (dimInner <= 16 && dimBOuter <= 16) { + workGroupSize[0] = 4; + } + } + + return {workGroupSize, elementsPerThread}; +} + export function computeWorkGroupSizeForConv2d( layout: {x: number[], y?: number[], z?: number[]}, outputShape: number[], isVec4 = false): [number, number, number] { @@ -86,25 +117,6 @@ export function computeWorkGroupSizeForConv2d( return [16, 16, 1]; } -export function computeWorkGroupSizeForMatMul( - dimAOuter: number, dimInner: number, - dimBOuter: number): [number, number, number] { - // These are experimental values. Usually, we need to adjust the work group - // size based on the input shapes to improve the EU occupancy. - // TODO: WebGPU limits the maximum allowed shared memory size as 16K. To make - // sure it doesn't exceed this limitations. Temporarily reduce the work group - // size to [8, 8, 1] and the work per thread size is [4, 4, 1]. But we should - // revisit it and find the balance between work group size and work per thread - // size. - if (dimAOuter === 1) { - return [32, 1, 1]; - } else if (dimBOuter === 1) { - return [1, 32, 1]; - } - - return [8, 8, 1]; -} - export function computeWorkPerThreadForConv2d( layout: {x: number[], y?: number[], z?: number[]}, outputShape: number[], isVec4 = false): [number, number, number] { @@ -162,7 +174,6 @@ export function isWebGPUSupported(): boolean { } export enum MatMulProgramType { - MatMulPackedVec4Program, MatMulReduceProgram, MatMulSplitKProgram, MatMulSmallOutputSizeProgram, From 89f59aa6a7dc793783b4519581cc27b4d9f558dc Mon Sep 17 00:00:00 2001 From: Linchenn <40653845+Linchenn@users.noreply.github.com> Date: Fri, 12 Aug 2022 13:14:39 -0700 Subject: [PATCH 2/3] [browserstack benchmark tool] batch benchmark code snippet (#6742) FEATURE * move code snippet * Update app.js * Update index.js --- e2e/benchmarks/browserstack-benchmark/app.js | 42 ++++++++++++++------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/e2e/benchmarks/browserstack-benchmark/app.js b/e2e/benchmarks/browserstack-benchmark/app.js index 088dde4fa6b..4104fa3c483 100644 --- a/e2e/benchmarks/browserstack-benchmark/app.js +++ b/e2e/benchmarks/browserstack-benchmark/app.js @@ -111,19 +111,35 @@ async function benchmarkAll(config) { for (backend of benchmarkInfo.backend) { for (model of benchmarkInfo.model) { - console.log( - `\nRunning ${model} model benchmarks over ${backend} backend...`); - const result = await benchmark({ - 'benchmark': { - 'model': model, - 'numRuns': benchmarkInfo.numRuns, - 'backend': backend, - 'codeSnippet': benchmarkInfo.codeSnippet || '', - 'setupCodeSnippetEnv': benchmarkInfo.setupCodeSnippetEnv || '' - }, - 'browsers': config.browsers - }); - allResults.push(result); + if (model === 'codeSnippet') { + for (codeSnippetPair of benchmarkInfo.codeSnippets) { + console.log( + `\nRunning codeSnippet benchmarks over ${backend} backend...`); + const result = await benchmark({ + 'benchmark': { + 'model': model, + 'numRuns': benchmarkInfo.numRuns, + 'backend': backend, + 'codeSnippet': codeSnippetPair.codeSnippet || '', + 'setupCodeSnippetEnv': codeSnippetPair.setupCodeSnippetEnv || '' + }, + 'browsers': config.browsers + }); + allResults.push(result); + } + } else { + console.log( + `\nRunning ${model} model benchmarks over ${backend} backend...`); + const result = await benchmark({ + 'benchmark': { + 'model': model, + 'numRuns': benchmarkInfo.numRuns, + 'backend': backend + }, + 'browsers': config.browsers + }); + allResults.push(result); + } } } console.log('\nAll benchmarks complete!'); From b02de70ec477d3fac953d59aff4e86623681a78c Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 16 Aug 2022 15:34:09 +0800 Subject: [PATCH 3/3] [webgpu] Add Atan2, IsNaN, Reciprocal (#6743) Typical use case: Atan2: FaceLandmarkDetection,attention_mesh. IsNan: ArPortraitDepth. Reciprocal: MoveNet-MultiPose. --- tfjs-backend-webgpu/src/binary_op_util.ts | 78 +++++++++---------- tfjs-backend-webgpu/src/kernels/Atan2.ts | 28 +++++++ tfjs-backend-webgpu/src/kernels/IsNaN.ts | 29 +++++++ tfjs-backend-webgpu/src/kernels/Reciprocal.ts | 28 +++++++ .../src/register_all_kernels.ts | 6 ++ tfjs-backend-webgpu/src/setup_test.ts | 9 ++- tfjs-backend-webgpu/src/unary_op_util.ts | 8 ++ tfjs-core/src/ops/binary_ops_test.ts | 30 +++++++ 8 files changed, 174 insertions(+), 42 deletions(-) create mode 100644 tfjs-backend-webgpu/src/kernels/Atan2.ts create mode 100644 tfjs-backend-webgpu/src/kernels/IsNaN.ts create mode 100644 tfjs-backend-webgpu/src/kernels/Reciprocal.ts diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index 80a822afe1b..c000197e86c 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -18,6 +18,7 @@ export enum BinaryOpType { MUL, ADD, + ATAN2, SUB, DIV, EQUAL, @@ -37,6 +38,31 @@ export enum BinaryOpType { COMPLEX_MULTIPLY_IMAG } +const CHECK_NAN_SNIPPET = ` + if (isnan(a)) { return a; } + if (isnan(b)) { return b; } + `; + +const CHECK_NAN_SNIPPET_VEC4_INNER = ` + if (isNaN.r) { + resultTemp.r = valueForNaN; + } + if (isNaN.g) { + resultTemp.g = valueForNaN; + } + if (isNaN.b) { + resultTemp.b = valueForNaN; + } + if (isNaN.a) { + resultTemp.a = valueForNaN; + } + `; + +const CHECK_NAN_SNIPPET_VEC4 = ` + let isNaN = isnanVec4(a) | isnanVec4(b); + ${CHECK_NAN_SNIPPET_VEC4_INNER} + `; + const ADD = 'return a + b;'; // (Ar + Ai)(Br + Bi) = // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr @@ -61,24 +87,6 @@ const LESS_EQUAL_VEC4 = 'return vec4(a <= b);'; const LOGICAL_AND = 'return f32(f32(a) >= 1.0 && f32(b) >= 1.0);'; const LOGICAL_AND_VEC4 = `return (vec4(a >= vec4(1.0)) * vec4(b >= vec4(1.0)));`; -const CHECK_NAN_SNIPPET = ` - if (isnan(a)) { return a; } - if (isnan(b)) { return b; } - `; -const CHECK_NAN_SNIPPET_VEC4 = ` - if (isNaN.r) { - resultTemp.r = uniforms.NAN; - } - if (isNaN.g) { - resultTemp.g = uniforms.NAN; - } - if (isNaN.b) { - resultTemp.b = uniforms.NAN; - } - if (isNaN.a) { - resultTemp.a = uniforms.NAN; - } - `; const INT_DIV = ` let s = sign(a) * sign(b); let ia = i32(round(a)); @@ -116,23 +124,11 @@ const NOT_EQUAL = ` return f32(a != b); `; const NOT_EQUAL_VEC4 = ` - var result = vec4(a != b); - var isANaN = isnanVec4(a); - var isBNaN = isnanVec4(b); - if (isANaN.r || isBNaN.r) { - result.r = 1.0; - } - if (isANaN.g || isBNaN.g) { - result.g = 1.0; - } - if (isANaN.b || isBNaN.b) { - result.b = 1.0; - } - if (isANaN.a || isBNaN.a) { - result.a = 1.0; - } + var resultTemp = vec4(a != b); + let valueForNaN = 1.0; + ${CHECK_NAN_SNIPPET_VEC4} - return result; + return resultTemp; `; const POW = ` if(a < 0.0 && floor(b) < b) { @@ -167,7 +163,8 @@ const POW_VEC4 = ` resultTemp.a = 1.0; } let isNaN = a < vec4(0.0) & floor(b) < b; - ${CHECK_NAN_SNIPPET_VEC4} + let valueForNaN = uniforms.NAN; + ${CHECK_NAN_SNIPPET_VEC4_INNER} return resultTemp; `; @@ -177,11 +174,12 @@ const PRELU_VEC4 = ` return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a); `; -function getMinMaxString(op: string, useVec4: boolean) { +function getBinaryWithNanString( + op: string, useVec4: boolean, valueForNaN = 'uniforms.NAN') { const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET; return useVec4 ? ` + let valueForNaN = ${valueForNaN}; var resultTemp = vec4(${op}(a, b)); - let isNaN = isnanVec4(a) | isnanVec4(b); ` + checkNanSnippet + ` return resultTemp; @@ -198,6 +196,8 @@ export function getBinaryOpString( return MUL; case BinaryOpType.ADD: return ADD; + case BinaryOpType.ATAN2: + return getBinaryWithNanString('atan2', useVec4); case BinaryOpType.SUB: return SUB; case BinaryOpType.DIV: @@ -223,9 +223,9 @@ export function getBinaryOpString( case BinaryOpType.PRELU: return useVec4 ? PRELU_VEC4 : PRELU; case BinaryOpType.MAX: - return getMinMaxString('max', useVec4); + return getBinaryWithNanString('max', useVec4); case BinaryOpType.MIN: - return getMinMaxString('min', useVec4); + return getBinaryWithNanString('min', useVec4); case BinaryOpType.POW: return useVec4 ? POW_VEC4 : POW; case BinaryOpType.COMPLEX_MULTIPLY_REAL: diff --git a/tfjs-backend-webgpu/src/kernels/Atan2.ts b/tfjs-backend-webgpu/src/kernels/Atan2.ts new file mode 100644 index 00000000000..ca4c4884f3e --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Atan2.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Atan2, KernelConfig} from '@tensorflow/tfjs-core'; +import {BinaryOpType} from '../binary_op_util'; +import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; + +export const atan2 = binaryKernelFunc({opType: BinaryOpType.ATAN2}); + +export const atan2Config: KernelConfig = { + kernelName: Atan2, + backendName: 'webgpu', + kernelFunc: atan2 +}; diff --git a/tfjs-backend-webgpu/src/kernels/IsNaN.ts b/tfjs-backend-webgpu/src/kernels/IsNaN.ts new file mode 100644 index 00000000000..e70592c290a --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/IsNaN.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IsNan, KernelConfig} from '@tensorflow/tfjs-core'; +import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; +import {UnaryOpType} from '../unary_op_util'; + +export const isNaN = + unaryKernelFunc({opType: UnaryOpType.IS_NAN, dtype: 'bool'}); + +export const isNaNConfig: KernelConfig = { + kernelName: IsNan, + backendName: 'webgpu', + kernelFunc: isNaN +}; diff --git a/tfjs-backend-webgpu/src/kernels/Reciprocal.ts b/tfjs-backend-webgpu/src/kernels/Reciprocal.ts new file mode 100644 index 00000000000..f5cc4c88b5e --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Reciprocal.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, Reciprocal} from '@tensorflow/tfjs-core'; +import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; +import {UnaryOpType} from '../unary_op_util'; + +export const reciprocal = unaryKernelFunc({opType: UnaryOpType.RECIPROCAL}); + +export const reciprocalConfig: KernelConfig = { + kernelName: Reciprocal, + backendName: 'webgpu', + kernelFunc: reciprocal +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index bc8b6fa415f..e4a2093a73b 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -22,6 +22,7 @@ import {addConfig} from './kernels/Add'; import {addNConfig} from './kernels/AddN'; import {argMaxConfig} from './kernels/ArgMax'; import {argMinConfig} from './kernels/ArgMin'; +import {atan2Config} from './kernels/Atan2'; import {avgPoolConfig} from './kernels/AvgPool'; import {batchMatMulConfig} from './kernels/BatchMatMul'; import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND'; @@ -59,6 +60,7 @@ import {greaterConfig} from './kernels/Greater'; import {greaterEqualConfig} from './kernels/GreaterEqual'; import {identityConfig} from './kernels/Identity'; import {imagConfig} from './kernels/Imag'; +import {isNaNConfig} from './kernels/IsNaN'; import {leakyReluConfig} from './kernels/LeakyRelu'; import {lessConfig} from './kernels/Less'; import {lessEqualConfig} from './kernels/LessEqual'; @@ -86,6 +88,7 @@ import {prodConfig} from './kernels/Prod'; import {rangeConfig} from './kernels/Range'; import {realConfig} from './kernels/Real'; import {realDivConfig} from './kernels/RealDiv'; +import {reciprocalConfig} from './kernels/Reciprocal'; import {reluConfig} from './kernels/Relu'; import {relu6Config} from './kernels/Relu6'; import {reshapeConfig} from './kernels/Reshape'; @@ -126,6 +129,7 @@ const kernelConfigs: KernelConfig[] = [ addNConfig, argMaxConfig, argMinConfig, + atan2Config, avgPoolConfig, batchMatMulConfig, batchToSpaceNDConfig, @@ -163,6 +167,7 @@ const kernelConfigs: KernelConfig[] = [ greaterEqualConfig, identityConfig, imagConfig, + isNaNConfig, leakyReluConfig, lessConfig, lessEqualConfig, @@ -190,6 +195,7 @@ const kernelConfigs: KernelConfig[] = [ rangeConfig, realConfig, realDivConfig, + reciprocalConfig, reluConfig, relu6Config, reshapeConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index bf4ed8e11e5..58e4ee0fd7d 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -40,6 +40,12 @@ const TEST_FILTERS: TestFilter[] = [ 'gradient', // Step kernel not yet implemented. ] }, + { + startsWith: 'atan2 ', + excludes: [ + 'gradient', // Not yet implemented. + ] + }, { startsWith: 'avgPool ', excludes: [ @@ -259,7 +265,6 @@ const TEST_FILTERS: TestFilter[] = [ 'any webgpu ', 'asin ', 'asinh ', - 'atan2 ', 'atanh ', 'avgPool3d ', 'avgPool3dBackprop ', @@ -280,7 +285,6 @@ const TEST_FILTERS: TestFilter[] = [ 'IRFFT ', 'isFinite ', 'isInf ', - 'isNaN ', 'linspace ', 'localResponseNormalization ', 'log1p ', @@ -297,7 +301,6 @@ const TEST_FILTERS: TestFilter[] = [ 'oneHot ', 'confusionMatrix ', // oneHot 'poolBackprop ', - 'reciprocal ', 'reverse1d ', 'reverse2d ', 'reverse3d ', diff --git a/tfjs-backend-webgpu/src/unary_op_util.ts b/tfjs-backend-webgpu/src/unary_op_util.ts index bcf7849d264..e49c2b96dd7 100644 --- a/tfjs-backend-webgpu/src/unary_op_util.ts +++ b/tfjs-backend-webgpu/src/unary_op_util.ts @@ -24,6 +24,7 @@ export enum UnaryOpType { EXP, EXPM1, FLOOR, + IS_NAN, LINEAR, LOG, LOGICAL_NOT, @@ -31,6 +32,7 @@ export enum UnaryOpType { RELU, RELU6, LEAKYRELU, + RECIPROCAL, RSQRT, SIN, SINH, @@ -68,6 +70,7 @@ const ELU_VEC4 = ` `; const EXP = `return exp(a);`; const FLOOR = `return floor(a);`; +const IS_NAN = `return f32(isnan(a));`; const LINEAR = `return a;`; const LOG = `if (a < 0.0) { return 1.0/0.0; } return log(a);`; @@ -78,6 +81,7 @@ const LEAKYRELU_VEC4 = ` let aLessThanZero = vec4(a < vec4(0.0)); return (aLessThanZero * (uniforms.alpha * a)) + ((vec4(1.0) - aLessThanZero) * a); `; +const RECIPROCAL = `return 1.0 / a;`; const RELU = `return select(a, 0.0, a < 0.0);`; const RELU6 = 'return clamp(a, 0.0, 6.0);'; const RELU6_VEC4 = @@ -118,6 +122,8 @@ export function getUnaryOpString(type: UnaryOpType, useVec4?: boolean): string { return EXPM1; case UnaryOpType.FLOOR: return FLOOR; + case UnaryOpType.IS_NAN: + return IS_NAN; case UnaryOpType.LINEAR: return LINEAR; case UnaryOpType.LOG: @@ -128,6 +134,8 @@ export function getUnaryOpString(type: UnaryOpType, useVec4?: boolean): string { return NEG; case UnaryOpType.LEAKYRELU: return useVec4 ? LEAKYRELU_VEC4 : LEAKYRELU; + case UnaryOpType.RECIPROCAL: + return RECIPROCAL; case UnaryOpType.RELU: return useVec4 ? RELU_VEC4 : RELU; case UnaryOpType.RELU6: diff --git a/tfjs-core/src/ops/binary_ops_test.ts b/tfjs-core/src/ops/binary_ops_test.ts index 47710c301cd..56f1de17031 100644 --- a/tfjs-core/src/ops/binary_ops_test.ts +++ b/tfjs-core/src/ops/binary_ops_test.ts @@ -1083,6 +1083,36 @@ describeWithFlags('atan2', ALL_ENVS, () => { expectArraysClose(await r.data(), expected); }); + it('atan2 vec4 NaNs', async () => { + const aValues = [1.0, 2.0, 3.0, 4.0]; + const cValues = [3.0, NaN, 3.0, 4.0]; + const a = tf.tensor2d(aValues, [4, 1]); + const c = tf.tensor2d(cValues, [4, 1]); + + const r = tf.atan2(a, c); + const expected = []; + + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atan2(aValues[i], cValues[i]); + } + expectArraysClose(await r.data(), expected); + }); + + it('atan2 vec4 all NaNs', async () => { + const aValues = [NaN, 2.0, NaN, NaN]; + const cValues = [3.0, NaN, 3.0, 4.0]; + const a = tf.tensor2d(aValues, [4, 1]); + const c = tf.tensor2d(cValues, [4, 1]); + + const r = tf.atan2(a, c); + const expected = []; + + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atan2(aValues[i], cValues[i]); + } + expectArraysClose(await r.data(), expected); + }); + it('gradient: Scalar', async () => { const a = tf.scalar(5); const b = tf.scalar(2);