Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webgpu: Support any component buffer #7426

Merged
merged 4 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 1 addition & 15 deletions tfjs-backend-webgpu/src/activation_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,7 @@ import {backend_util} from '@tensorflow/tfjs-core';

import {BinaryOpType, getBinaryOpString} from './binary_op_util';
import {getUnaryOpString, UnaryOpType} from './unary_op_util';

/**
 * Maps a component count to the matching f32 type string.
 *
 * @param component Number of f32 components; only 1, 2, 3 and 4 are accepted.
 * @returns 'f32' for a single component, otherwise 'vecN<f32>' for N
 *     components.
 * @throws Error if `component` is anything other than 1, 2, 3 or 4.
 */
export const typeSnippet = (component: number) => {
  // Explicit literal union keeps the inferred return type as narrow as the
  // set of strings this function can actually produce.
  const typeByComponent =
      new Map<number, 'f32'|'vec2<f32>'|'vec3<f32>'|'vec4<f32>'>([
        [1, 'f32'],
        [2, 'vec2<f32>'],
        [3, 'vec3<f32>'],
        [4, 'vec4<f32>'],
      ]);

  const resolved = typeByComponent.get(component);
  if (resolved === undefined) {
    throw new Error(`${component}-component is not supported.`);
  }
  return resolved;
};
import {typeSnippet} from './webgpu_program';

export function activationFnSnippet(
activation: backend_util.Activation, hasPreluActivationWeights = false,
Expand Down
7 changes: 5 additions & 2 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -784,8 +784,11 @@ export class WebGPUBackend extends KernelBackend {
programUniform.push({type: uniformsType, data: strides});
if (program.size) {
const size = util.sizeFromShape(program.outputShape);
programUniform.push(
{type: uniformsType, data: [program.isVec4 ? size / 4 : size]});
programUniform.push({
type: uniformsType,
data:
[program.outputComponent ? size / program.outputComponent : size]
});
}
}

Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/binary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {computeDispatch, flatDispatchLayout} from './webgpu_util';
export class BinaryOpProgram implements WebGPUProgram {
dispatch: [number, number, number];
dispatchLayout: {x: number[]};
outputComponent: number;
isVec4: boolean;
op: BinaryOpType;
outputShape: number[];
Expand Down Expand Up @@ -65,6 +66,7 @@ export class BinaryOpProgram implements WebGPUProgram {
if (util.arraysEqual(aShape, bShape) &&
util.sizeFromShape(aShape) % 4 === 0) {
this.isVec4 = true;
this.outputComponent = 4;
this.type = 'vec4';
this.workPerThread = 4;
} else {
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/clip_vec4_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class ClipVec4Program implements WebGPUProgram {
dispatch: [number, number, number];
workPerThread = 4;
workgroupSize: [number, number, number] = [64, 1, 1];
isVec4 = true;
outputComponent = 4;
size = true;

constructor(outputShape: number[]) {
Expand Down
16 changes: 9 additions & 7 deletions tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import {backend_util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import {WebGPUProgram} from './webgpu_program';
import {typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkgroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';

function conv2dCommonSnippet(
Expand Down Expand Up @@ -159,7 +159,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
variableTypes: string[];
variableComponents: number[];
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future, can we rename variableNames and variableComponents to inputNames and inputComponents?

uniforms =
`filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, dilations : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,`;
workgroupSize: [number, number, number];
Expand All @@ -176,6 +176,7 @@ export class Conv2DMMProgram implements WebGPUProgram {
tileInner: number;
innerElementSize: number;
isVec4?: boolean;
outputComponent: number;
private sequentialAccessByThreads: boolean;

constructor(
Expand All @@ -202,22 +203,23 @@ export class Conv2DMMProgram implements WebGPUProgram {
this.elementsPerThread);

if (this.isVec4) {
this.outputComponent = 4;
if (this.isChannelsLast && convInfo.inChannels % 4 !== 0) {
this.innerElementSize = 3;
this.variableTypes = ['f32', 'vec4<f32>'];
this.variableComponents = [1, 4];
} else {
this.innerElementSize = 4;
this.variableTypes = ['vec4<f32>', 'vec4<f32>'];
this.variableComponents = [4, 4];
}

if (addBias) {
this.variableNames.push('bias');
this.variableTypes.push('vec4<f32>');
this.variableComponents.push(4);
}

if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
this.variableTypes.push('vec4<f32>');
this.variableComponents.push(4);
}
} else {
this.innerElementSize = this.elementsPerThread[0];
Expand Down
10 changes: 6 additions & 4 deletions tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/

import {backend_util, util} from '@tensorflow/tfjs-core';
import {typeSnippet} from './activation_util';

import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
import {WebGPUProgram} from './webgpu_program';
import {typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkgroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util';

function conv2dTransposeCommonSnippet(innerElementSize = 4) {
Expand Down Expand Up @@ -117,12 +117,13 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram {
dispatchLayout: {x: number[], y: number[], z: number[]};
dispatch: [number, number, number];
variableNames = ['x', 'W'];
variableTypes: string[];
variableComponents: number[];
uniforms =
'filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, outBackprop : vec4<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,';
workgroupSize: [number, number, number];
elementsPerThread: [number, number, number];
isVec4?: boolean;
outputComponent: number;

constructor(convInfo: backend_util.Conv2DInfo) {
this.outputShape = convInfo.inShape;
Expand All @@ -143,7 +144,8 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram {
this.elementsPerThread);

if (this.isVec4) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need to clean up isVec4 here.

this.variableTypes = ['vec4<f32>', 'f32'];
this.outputComponent = 4;
this.variableComponents = [4, 1];
}

this.shaderKey =
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
size = false;
isVec4 = false;
workPerThread = 1;
outputComponent: number;

constructor(convInfo: backend_util.Conv2DInfo) {
this.outputShape = convInfo.inShape;
Expand All @@ -41,6 +42,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram {
if (this.isVec4) {
// TODO: Expand to any value.
this.workPerThread = 2;
this.outputComponent = 4;
this.workgroupSize = [4, 4, 4];
this.dispatchLayout = {x: [3], y: [2], z: [0, 1]};
this.dispatch = computeDispatch(
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
addBias: boolean;
activation: backend_util.Activation;
hasPreluActivation: boolean;
isVec4 = true;
outputComponent = 4;

constructor(
convInfo: backend_util.Conv2DInfo, addBias = false,
Expand Down
7 changes: 5 additions & 2 deletions tfjs-backend-webgpu/src/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
*/

import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';
import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';

import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, computeWorkgroupInfoForMatMul} from './webgpu_util';

export function matMulReadFnSource(
Expand Down Expand Up @@ -509,6 +510,7 @@ export class MatMulPackedProgram implements WebGPUProgram {
tileInner: number;
isVectorA: boolean;
isVec4: boolean;
outputComponent: number;
private sequentialAccessByThreads: boolean;

constructor(
Expand All @@ -523,6 +525,7 @@ export class MatMulPackedProgram implements WebGPUProgram {
this.isVec4 = ((dimInner % 4 === 0 && !transposeA) ||
(outputShape[1] % 4 === 0 && transposeA)) &&
outputShape[2] % 4 === 0 && !transposeB;
this.outputComponent = this.isVec4 ? 4 : 1;
this.isVectorA = outputShape[1] === 1 && !transposeA;

if (!this.isVec4 && this.isVectorA) {
Expand Down
30 changes: 15 additions & 15 deletions tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';

import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util';
import {activationFnSnippet, biasActivationSnippet} from './activation_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source, matMulReadFnSource} from './matmul_packed_webgpu';
import {atomicAddSnippet} from './shader_util';
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class MatMulSplitKProgram implements WebGPUProgram {
Expand All @@ -35,7 +35,7 @@ export class MatMulSplitKProgram implements WebGPUProgram {
transposeA: boolean;
transposeB: boolean;
atomic = true;
isVec4 = false;
outputComponent: number;
splitedDimInner = 128;

constructor(
Expand All @@ -46,12 +46,12 @@ export class MatMulSplitKProgram implements WebGPUProgram {
() => 'MatMulSplitKProgram only supports batch = 1.');
this.outputShape = outputShape;
this.dispatchLayout = {x: [2], y: [1], z: [0, 3]};
this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 ||
!transposeA && dimInner % 4 === 0) &&
const isVec4 = (transposeA && this.outputShape[1] % 4 === 0 ||
!transposeA && dimInner % 4 === 0) &&
this.outputShape[2] % 4 === 0;
this.elementsPerThread = [4, 4, this.splitedDimInner];

if (!this.isVec4) {
this.outputComponent = isVec4 ? 4 : 1;
if (!isVec4) {
if (this.outputShape[1] < 16) {
this.elementsPerThread[1] = 1;
}
Expand All @@ -71,11 +71,11 @@ export class MatMulSplitKProgram implements WebGPUProgram {
this.transposeA = transposeA;
this.transposeB = transposeB;
this.shaderKey = `matMulSplitK_${transposeA}_${transposeB}_${
this.elementsPerThread}_${this.isVec4}`;
this.elementsPerThread}_${this.outputComponent}`;
}

getUserCode(): string {
const component = this.isVec4 ? 4 : 1;
const component = this.outputComponent;
const userCode = `
${
matMulReadFnSource(
Expand All @@ -97,12 +97,12 @@ export class MatMulSplitKProgram implements WebGPUProgram {
}
}
${
this.isVec4 ? makeMatMulPackedVec4Source(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner) :
makeMatMulPackedSource(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner)}
component === 4 ? makeMatMulPackedVec4Source(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner) :
makeMatMulPackedSource(
this.elementsPerThread, this.workgroupSize,
this.transposeA, 32, true, this.splitedDimInner)}
`;
return userCode;
}
Expand Down
5 changes: 3 additions & 2 deletions tfjs-backend-webgpu/src/scatter_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
*/

import {DataType} from '@tensorflow/tfjs-core';

import {atomicAddSnippet} from './shader_util';
import {getCoordsDataType, getMainHeaderString as main, mapToWgslTypes, WebGPUProgram} from './webgpu_program';
import {dataTypeToGPUType, getCoordsDataType, getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class ScatterProgram implements WebGPUProgram {
Expand Down Expand Up @@ -107,7 +108,7 @@ export class ScatterProgram implements WebGPUProgram {
flattenedIndex = flattenedIndex + indexInside * ${strideString};
}
let updateValue =
${mapToWgslTypes(this.type, false)}(${updatesSnippet});
${dataTypeToGPUType(this.type)}(${updatesSnippet});
let flatIndex = getOutputIndexFromCoords(${outCoordsString});

${
Expand Down
Loading