Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/tensorflow/tfjs into jax2…
Browse files Browse the repository at this point in the history
…tfjs
  • Loading branch information
marcvanzee committed Aug 16, 2022
2 parents 15f4a77 + b02de70 commit 41d12f0
Show file tree
Hide file tree
Showing 20 changed files with 665 additions and 555 deletions.
42 changes: 29 additions & 13 deletions e2e/benchmarks/browserstack-benchmark/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,35 @@ async function benchmarkAll(config) {

for (backend of benchmarkInfo.backend) {
for (model of benchmarkInfo.model) {
console.log(
`\nRunning ${model} model benchmarks over ${backend} backend...`);
const result = await benchmark({
'benchmark': {
'model': model,
'numRuns': benchmarkInfo.numRuns,
'backend': backend,
'codeSnippet': benchmarkInfo.codeSnippet || '',
'setupCodeSnippetEnv': benchmarkInfo.setupCodeSnippetEnv || ''
},
'browsers': config.browsers
});
allResults.push(result);
if (model === 'codeSnippet') {
for (codeSnippetPair of benchmarkInfo.codeSnippets) {
console.log(
`\nRunning codeSnippet benchmarks over ${backend} backend...`);
const result = await benchmark({
'benchmark': {
'model': model,
'numRuns': benchmarkInfo.numRuns,
'backend': backend,
'codeSnippet': codeSnippetPair.codeSnippet || '',
'setupCodeSnippetEnv': codeSnippetPair.setupCodeSnippetEnv || ''
},
'browsers': config.browsers
});
allResults.push(result);
}
} else {
console.log(
`\nRunning ${model} model benchmarks over ${backend} backend...`);
const result = await benchmark({
'benchmark': {
'model': model,
'numRuns': benchmarkInfo.numRuns,
'backend': backend
},
'browsers': config.browsers
});
allResults.push(result);
}
}
}
console.log('\nAll benchmarks complete!');
Expand Down
78 changes: 39 additions & 39 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
export enum BinaryOpType {
MUL,
ADD,
ATAN2,
SUB,
DIV,
EQUAL,
Expand All @@ -37,6 +38,31 @@ export enum BinaryOpType {
COMPLEX_MULTIPLY_IMAG
}

/**
 * Shader-source fragment for the non-vec4 path: if operand `a` is NaN it
 * returns `a`, otherwise if operand `b` is NaN it returns `b`. It is meant to
 * be spliced into a generated function body that has `a` and `b` in scope.
 */
const CHECK_NAN_SNIPPET = `
if (isnan(a)) { return a; }
if (isnan(b)) { return b; }
`;

/**
 * Shader-source fragment for the vec4 path: for each of the components
 * r, g, b and a whose `isNaN` flag is set, overwrites that component of
 * `resultTemp` with `valueForNaN`.
 *
 * The fragment declares none of `isNaN`, `resultTemp` or `valueForNaN`; the
 * shader code it is spliced into must define all three beforehand.
 */
const CHECK_NAN_SNIPPET_VEC4_INNER = `
if (isNaN.r) {
resultTemp.r = valueForNaN;
}
if (isNaN.g) {
resultTemp.g = valueForNaN;
}
if (isNaN.b) {
resultTemp.b = valueForNaN;
}
if (isNaN.a) {
resultTemp.a = valueForNaN;
}
`;

/**
 * Shader-source fragment for the vec4 path: declares `isNaN` as the
 * combination of `isnanVec4(a)` and `isnanVec4(b)`, then appends
 * CHECK_NAN_SNIPPET_VEC4_INNER to patch the flagged components.
 *
 * The surrounding shader code must still define `resultTemp` and
 * `valueForNaN` before this fragment, and must not declare `isNaN` itself.
 */
const CHECK_NAN_SNIPPET_VEC4 = `
let isNaN = isnanVec4(a) | isnanVec4(b);
${CHECK_NAN_SNIPPET_VEC4_INNER}
`;

const ADD = 'return a + b;';
// (Ar + i*Ai)(Br + i*Bi) =
// ArBr + i*ArBi + i*AiBr + i^2*AiBi = (ArBr - AiBi) + i*(ArBi + AiBr)
Expand All @@ -61,24 +87,6 @@ const LESS_EQUAL_VEC4 = 'return vec4<f32>(a <= b);';
const LOGICAL_AND = 'return f32(f32(a) >= 1.0 && f32(b) >= 1.0);';
const LOGICAL_AND_VEC4 = `return (vec4<f32>(a >= vec4<f32>(1.0)) *
vec4<f32>(b >= vec4<f32>(1.0)));`;
const CHECK_NAN_SNIPPET = `
if (isnan(a)) { return a; }
if (isnan(b)) { return b; }
`;
const CHECK_NAN_SNIPPET_VEC4 = `
if (isNaN.r) {
resultTemp.r = uniforms.NAN;
}
if (isNaN.g) {
resultTemp.g = uniforms.NAN;
}
if (isNaN.b) {
resultTemp.b = uniforms.NAN;
}
if (isNaN.a) {
resultTemp.a = uniforms.NAN;
}
`;
const INT_DIV = `
let s = sign(a) * sign(b);
let ia = i32(round(a));
Expand Down Expand Up @@ -116,23 +124,11 @@ const NOT_EQUAL = `
return f32(a != b);
`;
const NOT_EQUAL_VEC4 = `
var result = vec4<f32>(a != b);
var isANaN = isnanVec4(a);
var isBNaN = isnanVec4(b);
if (isANaN.r || isBNaN.r) {
result.r = 1.0;
}
if (isANaN.g || isBNaN.g) {
result.g = 1.0;
}
if (isANaN.b || isBNaN.b) {
result.b = 1.0;
}
if (isANaN.a || isBNaN.a) {
result.a = 1.0;
}
var resultTemp = vec4<f32>(a != b);
let valueForNaN = 1.0;
${CHECK_NAN_SNIPPET_VEC4}
return result;
return resultTemp;
`;
const POW = `
if(a < 0.0 && floor(b) < b) {
Expand Down Expand Up @@ -167,7 +163,8 @@ const POW_VEC4 = `
resultTemp.a = 1.0;
}
let isNaN = a < vec4<f32>(0.0) & floor(b) < b;
${CHECK_NAN_SNIPPET_VEC4}
let valueForNaN = uniforms.NAN;
${CHECK_NAN_SNIPPET_VEC4_INNER}
return resultTemp;
`;

Expand All @@ -177,11 +174,12 @@ const PRELU_VEC4 = `
return (aLessThanZero * (b * a)) + ((vec4<f32>(1.0) - aLessThanZero) * a);
`;

function getMinMaxString(op: string, useVec4: boolean) {
function getBinaryWithNanString(
op: string, useVec4: boolean, valueForNaN = 'uniforms.NAN') {
const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET;
return useVec4 ? `
let valueForNaN = ${valueForNaN};
var resultTemp = vec4<f32>(${op}(a, b));
let isNaN = isnanVec4(a) | isnanVec4(b);
` + checkNanSnippet +
`
return resultTemp;
Expand All @@ -198,6 +196,8 @@ export function getBinaryOpString(
return MUL;
case BinaryOpType.ADD:
return ADD;
case BinaryOpType.ATAN2:
return getBinaryWithNanString('atan2', useVec4);
case BinaryOpType.SUB:
return SUB;
case BinaryOpType.DIV:
Expand All @@ -223,9 +223,9 @@ export function getBinaryOpString(
case BinaryOpType.PRELU:
return useVec4 ? PRELU_VEC4 : PRELU;
case BinaryOpType.MAX:
return getMinMaxString('max', useVec4);
return getBinaryWithNanString('max', useVec4);
case BinaryOpType.MIN:
return getMinMaxString('min', useVec4);
return getBinaryWithNanString('min', useVec4);
case BinaryOpType.POW:
return useVec4 ? POW_VEC4 : POW;
case BinaryOpType.COMPLEX_MULTIPLY_REAL:
Expand Down
12 changes: 5 additions & 7 deletions tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import {backend_util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {makeMatMulPackedVec4Source} from './matmul_packed_vec4_webgpu';
import {makeMatMulPackedSource} from './matmul_packed_webgpu';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import {WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkGroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';

Expand Down Expand Up @@ -251,14 +250,13 @@ export class Conv2DMMProgram implements WebGPUProgram {
getUserCode(): string {
const matMulSource = this.isVec4 ?
makeMatMulPackedVec4Source(
this.elementsPerThread, this.tileAOuter, this.tileBOuter,
this.tileInner, this.innerElementSize, !this.isChannelsLast) :
this.elementsPerThread, this.workGroupSize, !this.isChannelsLast,
this.tileInner) :
makeMatMulPackedSource(
this.elementsPerThread, this.workGroupSize, !this.isChannelsLast,
this.tileInner);
const elementsSize = this.isVec4 ?
[this.isChannelsLast ? this.innerElementSize : 4, 4, 4] :
[1, 1, 1];
const elementsSize =
this.isVec4 ? [this.innerElementSize, 4, 4] : [1, 1, 1];
const userCode = `
${
conv2dCommonSnippet(
Expand Down
23 changes: 5 additions & 18 deletions tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

import {backend_util, util} from '@tensorflow/tfjs-core';
import {typeSnippet} from './activation_util';
import {makeMatMulPackedVec4Source} from './matmul_packed_vec4_webgpu';
import {makeMatMulPackedSource} from './matmul_packed_webgpu';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import {WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkGroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';

Expand Down Expand Up @@ -123,10 +122,6 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram {
'filterDims : vec2<i32>, pads : vec2<i32>, stride : vec2<i32>, outBackprop : vec4<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,';
workGroupSize: [number, number, number];
elementsPerThread: [number, number, number];
tileAOuter: number;
tileBOuter: number;
tileInner: number;
innerElementSize: number;
isVec4?: boolean;

constructor(convInfo: backend_util.Conv2DInfo) {
Expand All @@ -148,24 +143,16 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram {
this.elementsPerThread);

if (this.isVec4) {
this.innerElementSize = 4;
this.variableTypes = ['vec4<f32>', 'f32'];
} else {
this.innerElementSize = this.elementsPerThread[0];
}
this.tileAOuter = this.workGroupSize[1] * this.elementsPerThread[1];
this.tileBOuter = this.workGroupSize[0] * this.elementsPerThread[0];
this.tileInner = Math.max(
this.workGroupSize[0] * this.innerElementSize, this.workGroupSize[1]);
this.shaderKey = `conv2DDerInputMM_${this.isVec4}_${
this.elementsPerThread}_${this.innerElementSize}`;

this.shaderKey =
`conv2DDerInputMM_${this.isVec4}_${this.elementsPerThread}`;
}

getUserCode(): string {
const matMulSource = this.isVec4 ?
makeMatMulPackedVec4Source(
this.elementsPerThread, this.tileAOuter, this.tileBOuter,
this.tileInner, this.innerElementSize) :
makeMatMulPackedVec4Source(this.elementsPerThread, this.workGroupSize) :
makeMatMulPackedSource(this.elementsPerThread, this.workGroupSize);
const userCode = `
${conv2dTransposeCommonSnippet(this.isVec4 ? 4 : 1)}
Expand Down
5 changes: 0 additions & 5 deletions tfjs-backend-webgpu/src/flags_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ ENV.registerFlag('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', () => 15);
*/
ENV.registerFlag('WEBGPU_CPU_FORWARD', () => true);

/**
* Thread register block size for matmul kernel.
*/
ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4);

/**
* This flag is used to test different types of matmul programs.
*
Expand Down
28 changes: 28 additions & 0 deletions tfjs-backend-webgpu/src/kernels/Atan2.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Atan2, KernelConfig} from '@tensorflow/tfjs-core';
import {BinaryOpType} from '../binary_op_util';
import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';

/**
 * Kernel function for the Atan2 op, produced by `binaryKernelFunc` with the
 * `BinaryOpType.ATAN2` op type.
 */
export const atan2 = binaryKernelFunc({opType: BinaryOpType.ATAN2});

/**
 * Kernel configuration that associates the `Atan2` kernel name with the
 * 'webgpu' backend and the `atan2` kernel function.
 */
export const atan2Config: KernelConfig = {
kernelName: Atan2,
backendName: 'webgpu',
kernelFunc: atan2
};
20 changes: 2 additions & 18 deletions tfjs-backend-webgpu/src/kernels/BatchMatMul_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import {backend_util, broadcast_util, env, TensorInfo, util} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {MatMulPackedVec4Program} from '../matmul_packed_vec4_webgpu';
import {MatMulPackedProgram} from '../matmul_packed_webgpu';
import {MatMulReduceProgram} from '../matmul_reduce_webgpu';
import {MatMulSmallOutputSizeProgram} from '../matmul_small_output_size_webgpu';
Expand Down Expand Up @@ -93,9 +92,6 @@ export function batchMatMulImpl({
const batchDim = Math.max(batchDimA, batchDimB);
const batchAEqualOne = batchDimA === 1;
const batchBEqualOne = batchDimB === 1;
const useVec4 = ((innerShapeA % 4 === 0 && !transposeA) ||
(outerShapeA % 4 === 0 && transposeA)) &&
outerShapeB % 4 === 0 && !transposeB;

const inputs: TensorInfo[] = [a3d, b3d];
const dimensions = [
Expand Down Expand Up @@ -133,22 +129,12 @@ export function batchMatMulImpl({
(outerShapeB <= 16 &&
(outerShapeA <= 512 || innerShapeA >= 2 * outerShapeA))) {
matmulProgramType = MatMulProgramType.MatMulSmallOutputSizeProgram;
} else if (useVec4) {
// TODO: Currently we need to make sure that innerShapeA and outerShapeB
// are divisible by 4 since we use vec4 to get data. In future, we can
// remove this limitation by insert 0 to pack data.
matmulProgramType = MatMulProgramType.MatMulPackedVec4Program;
} else {
matmulProgramType = MatMulProgramType.MatMulPackedProgram;
}
}

switch (matmulProgramType) {
case MatMulProgramType.MatMulPackedVec4Program:
program = new MatMulPackedVec4Program(
a3dShape, outputShape, batchAEqualOne, batchBEqualOne, transposeA,
bias, activation, preluActivationWeights);
break;
case MatMulProgramType.MatMulReduceProgram:
program = new MatMulReduceProgram(
outputShape, batchAEqualOne, batchBEqualOne, transposeA, transposeB,
Expand Down Expand Up @@ -199,10 +185,8 @@ export function batchMatMulImpl({
break;
case MatMulProgramType.MatMulPackedProgram:
program = new MatMulPackedProgram(
a3dShape, outputShape,
env().get('WEBGPU_MATMUL_WORK_PER_THREAD') as number, batchAEqualOne,
batchBEqualOne, transposeA, transposeB, bias, activation,
preluActivationWeights);
a3dShape, outputShape, batchAEqualOne, batchBEqualOne, transposeA,
transposeB, bias, activation, preluActivationWeights);
break;
default:
throw new Error(`Unsupported MatMulProgramType ${matmulProgramType}.`);
Expand Down
29 changes: 29 additions & 0 deletions tfjs-backend-webgpu/src/kernels/IsNaN.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {IsNan, KernelConfig} from '@tensorflow/tfjs-core';
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
import {UnaryOpType} from '../unary_op_util';

/**
 * Kernel function for the IsNan op, produced by `unaryKernelFunc` with the
 * `UnaryOpType.IS_NAN` op type and `dtype: 'bool'` (presumably the output
 * dtype; confirm against `unaryKernelFunc`).
 *
 * NOTE(review): inside this module the binding shadows the global `isNaN`
 * function, so any code added here must not rely on the built-in.
 */
export const isNaN =
unaryKernelFunc({opType: UnaryOpType.IS_NAN, dtype: 'bool'});

/**
 * Kernel configuration that associates the `IsNan` kernel name with the
 * 'webgpu' backend and the `isNaN` kernel function defined in this module.
 */
export const isNaNConfig: KernelConfig = {
kernelName: IsNan,
backendName: 'webgpu',
kernelFunc: isNaN
};
Loading

0 comments on commit 41d12f0

Please sign in to comment.