webgpu: Support any component buffer (#7426)
* webgpu: Support any component buffer

* address comments

* Fix build error
qjia7 authored Mar 1, 2023
1 parent 8c835f0 commit 13d37d0
Showing 12 changed files with 137 additions and 199 deletions.
16 changes: 1 addition & 15 deletions tfjs-backend-webgpu/src/activation_util.ts
@@ -19,21 +19,7 @@ import {backend_util} from '@tensorflow/tfjs-core';

import {BinaryOpType, getBinaryOpString} from './binary_op_util';
import {getUnaryOpString, UnaryOpType} from './unary_op_util';

export const typeSnippet = (component: number) => {
switch (component) {
case 1:
return 'f32';
case 2:
return 'vec2<f32>';
case 3:
return 'vec3<f32>';
case 4:
return 'vec4<f32>';
default:
throw new Error(`${component}-component is not supported.`);
}
};
import {typeSnippet} from './webgpu_program';

export function activationFnSnippet(
activation: backend_util.Activation, hasPreluActivationWeights = false,
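Note: typeSnippet is not deleted outright; it moves into webgpu_program.ts, which this file (and several files below) now imports it from. A minimal sketch of the relocated helper, assuming its body is unchanged from the code removed above:

// Sketch of typeSnippet as exported from webgpu_program.ts (assumed to keep
// the body deleted from activation_util.ts above).
export const typeSnippet = (component: number) => {
  switch (component) {
    case 1:
      return 'f32';
    case 2:
      return 'vec2<f32>';
    case 3:
      return 'vec3<f32>';
    case 4:
      return 'vec4<f32>';
    default:
      throw new Error(`${component}-component is not supported.`);
  }
};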
7 changes: 5 additions & 2 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
@@ -784,8 +784,11 @@ export class WebGPUBackend extends KernelBackend {
programUniform.push({type: uniformsType, data: strides});
if (program.size) {
const size = util.sizeFromShape(program.outputShape);
programUniform.push(
{type: uniformsType, data: [program.isVec4 ? size / 4 : size]});
programUniform.push({
type: uniformsType,
data:
[program.outputComponent ? size / program.outputComponent : size]
});
}
}

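For context, this hunk generalizes the size uniform from the vec4-only case (size / 4) to any component count: the flat output size is divided by whatever outputComponent the program declares, or left as-is when no component is set. A small worked example with hypothetical values:

// Hypothetical values for illustration only.
const size = 24;            // util.sizeFromShape(program.outputShape)
const outputComponent = 3;  // e.g. the output buffer is written as vec3<f32>
// Previously only size / 4 (isVec4) or size was possible; now any count works.
const sizeUniform = outputComponent ? size / outputComponent : size;  // 24 / 3 = 8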
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/binary_op_webgpu.ts
@@ -24,6 +24,7 @@ import {computeDispatch, flatDispatchLayout} from './webgpu_util';
export class BinaryOpProgram implements WebGPUProgram {
dispatch: [number, number, number];
dispatchLayout: {x: number[]};
outputComponent: number;
isVec4: boolean;
op: BinaryOpType;
outputShape: number[];
@@ -65,6 +66,7 @@ export class BinaryOpProgram implements WebGPUProgram {
if (util.arraysEqual(aShape, bShape) &&
util.sizeFromShape(aShape) % 4 === 0) {
this.isVec4 = true;
this.outputComponent = 4;
this.type = 'vec4';
this.workPerThread = 4;
} else {
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/clip_vec4_webgpu.ts
@@ -27,7 +27,7 @@ export class ClipVec4Program implements WebGPUProgram {
dispatch: [number, number, number];
workPerThread = 4;
workgroupSize: [number, number, number] = [64, 1, 1];
isVec4 = true;
outputComponent = 4;
size = true;

constructor(outputShape: number[]) {
16 changes: 9 additions & 7 deletions tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts
@@ -17,9 +17,9 @@

import {backend_util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import {WebGPUProgram} from './webgpu_program';
import {typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkgroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';

function conv2dCommonSnippet(
@@ -159,7 +159,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
variableTypes: string[];
variableComponents: number[];
uniforms =
`filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, dilations : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,`;
workgroupSize: [number, number, number];
@@ -176,6 +176,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
tileInner: number;
innerElementSize: number;
isVec4?: boolean;
outputComponent: number;
private sequentialAccessByThreads: boolean;

constructor(
@@ -202,22 +203,23 @@
this.elementsPerThread);

if (this.isVec4) {
this.outputComponent = 4;
if (this.isChannelsLast && convInfo.inChannels % 4 !== 0) {
this.innerElementSize = 3;
this.variableTypes = ['f32', 'vec4<f32>'];
this.variableComponents = [1, 4];
} else {
this.innerElementSize = 4;
this.variableTypes = ['vec4<f32>', 'vec4<f32>'];
this.variableComponents = [4, 4];
}

if (addBias) {
this.variableNames.push('bias');
this.variableTypes.push('vec4<f32>');
this.variableComponents.push(4);
}

if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
this.variableTypes.push('vec4<f32>');
this.variableComponents.push(4);
}
} else {
this.innerElementSize = this.elementsPerThread[0];
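Replacing variableTypes (explicit WGSL type strings) with variableComponents (component counts) implies the shader builder now derives the WGSL type from the count, presumably via the relocated typeSnippet. An illustrative sketch, not taken from this diff:

// Illustration only: the counts set in this program map to the same WGSL types
// the old string array spelled out (assumes typeSnippet does the conversion
// inside webgpu_program.ts).
import {typeSnippet} from './webgpu_program';

const variableComponents = [1, 4];  // x read as scalars, W read as vec4
const wgslTypes = variableComponents.map(c => typeSnippet(c));
console.log(wgslTypes);             // ['f32', 'vec4<f32>']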
10 changes: 6 additions & 4 deletions tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts
@@ -16,9 +16,9 @@
*/

import {backend_util, util} from '@tensorflow/tfjs-core';
import {typeSnippet} from './activation_util';

import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import {WebGPUProgram} from './webgpu_program';
import {typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkgroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';

function conv2dTransposeCommonSnippet(innerElementSize = 4) {
@@ -117,12 +117,13 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram {
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
variableTypes: string[];
variableComponents: number[];
uniforms =
'filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, outBackprop : vec4<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,';
workgroupSize: [number, number, number];
elementsPerThread: [number, number, number];
isVec4?: boolean;
outputComponent: number;

constructor(convInfo: backend_util.Conv2DInfo) {
this.outputShape = convInfo.inShape;
@@ -143,7 +144,8 @@
this.elementsPerThread);

if (this.isVec4) {
this.variableTypes = ['vec4<f32>', 'f32'];
this.outputComponent = 4;
this.variableComponents = [4, 1];
}

this.shaderKey =
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
@@ -32,6 +32,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
size = false;
isVec4 = false;
workPerThread = 1;
outputComponent: number;

constructor(convInfo: backend_util.Conv2DInfo) {
this.outputShape = convInfo.inShape;
@@ -41,6 +42,7 @@
if (this.isVec4) {
// TODO: Expand to any value.
this.workPerThread = 2;
this.outputComponent = 4;
this.workgroupSize = [4, 4, 4];
this.dispatchLayout = {x: [3], y: [2], z: [0, 1]};
this.dispatch = computeDispatch(
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts
@@ -33,7 +33,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
addBias: boolean;
activation: backend_util.Activation;
hasPreluActivation: boolean;
isVec4 = true;
outputComponent = 4;

constructor(
convInfo: backend_util.Conv2DInfo, addBias = false,
7 changes: 5 additions & 2 deletions tfjs-backend-webgpu/src/matmul_packed_webgpu.ts
@@ -16,8 +16,9 @@
*/

import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';
import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';

import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkgroupInfoForMatMul} from './webgpu_util';

export function matMulReadFnSource(
@@ -509,6 +510,7 @@ export class MatMulPackedProgram implements WebGPUProgram {
tileInner: number;
isVectorA: boolean;
isVec4: boolean;
outputComponent: number;
private sequentialAccessByThreads: boolean;

constructor(
@@ -523,6 +525,7 @@
this.isVec4 = ((dimInner % 4 === 0 && !transposeA) ||
(outputShape[1] % 4 === 0 && transposeA)) &&
outputShape[2] % 4 === 0 && !transposeB;
this.outputComponent = this.isVec4 ? 4 : 1;
this.isVectorA = outputShape[1] === 1 && !transposeA;

if (!this.isVec4 && this.isVectorA) {
30 changes: 15 additions & 15 deletions tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts
@@ -17,10 +17,10 @@

import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source, matMulReadFnSource} from './matmul_packed_webgpu';
import {atomicAddSnippet} from './shader_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class MatMulSplitKProgram implements WebGPUProgram {
Expand All @@ -35,7 +35,7 @@ export class MatMulSplitKProgram implements WebGPUProgram {
transposeA: boolean;
transposeB: boolean;
atomic = true;
isVec4 = false;
outputComponent: number;
splitedDimInner = 128;

constructor(
@@ -46,12 +46,12 @@
() => 'MatMulSplitKProgram only supports batch = 1.');
this.outputShape = outputShape;
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 ||
!transposeA && dimInner % 4 === 0) &&
const isVec4 = (transposeA && this.outputShape[1] % 4 === 0 ||
!transposeA && dimInner % 4 === 0) &&
this.outputShape[2] % 4 === 0;
this.elementsPerThread = [4, 4, this.splitedDimInner];

if (!this.isVec4) {
this.outputComponent = isVec4 ? 4 : 1;
if (!isVec4) {
if (this.outputShape[1] < 16) {
this.elementsPerThread[1] = 1;
}
@@ -71,11 +71,11 @@
this.transposeA = transposeA;
this.transposeB = transposeB;
this.shaderKey = `matMulSplitK_${transposeA}_${transposeB}_${
this.elementsPerThread}_${this.isVec4}`;
this.elementsPerThread}_${this.outputComponent}`;
}

getUserCode(): string {
const component = this.isVec4 ? 4 : 1;
const component = this.outputComponent;
const userCode = `
${
matMulReadFnSource(
@@ -97,12 +97,12 @@
}
}
${
this.isVec4 ? makeMatMulPackedVec4Source(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner) :
makeMatMulPackedSource(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner)}
component === 4 ? makeMatMulPackedVec4Source(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner) :
makeMatMulPackedSource(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner)}
`;
return userCode;
}
5 changes: 3 additions & 2 deletions tfjs-backend-webgpu/src/scatter_webgpu.ts
@@ -16,8 +16,9 @@
*/

import {DataType} from '@tensorflow/tfjs-core';

import {atomicAddSnippet} from './shader_util';
import {getCoordsDataType, getMainHeaderString as main, mapToWgslTypes, WebGPUProgram} from './webgpu_program';
import {dataTypeToGPUType, getCoordsDataType, getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class ScatterProgram implements WebGPUProgram {
@@ -107,7 +108,7 @@ export class ScatterProgram implements WebGPUProgram {
flattenedIndex = flattenedIndex + indexInside * ${strideString};
}
let updateValue =
${mapToWgslTypes(this.type, false)}(${updatesSnippet});
${dataTypeToGPUType(this.type)}(${updatesSnippet});
let flatIndex = getOutputIndexFromCoords(${outCoordsString});
${
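mapToWgslTypes(type, isVec4) is replaced here by dataTypeToGPUType(type), defined in webgpu_program.ts and not shown in this diff. A hypothetical sketch of the scalar mapping this call site would need (the name suffix, branches, and bool handling are assumptions, not the actual implementation):

// Hypothetical sketch only; the real dataTypeToGPUType lives in webgpu_program.ts.
import {DataType} from '@tensorflow/tfjs-core';

function dataTypeToGPUTypeSketch(type: DataType): string {
  if (type === 'float32') {
    return 'f32';
  } else if (type === 'int32' || type === 'bool') {
    // Assumption: bool tensors are stored as i32 on the WebGPU backend.
    return 'i32';
  }
  throw new Error(`type ${type} is not supported in this sketch.`);
}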