From bffb70365790228ded5296054232fab5a3f83f7d Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 21 Jul 2022 15:12:26 +0800 Subject: [PATCH 1/3] webgpu: Support any component buffer --- tfjs-backend-webgpu/src/activation_util.ts | 16 +- tfjs-backend-webgpu/src/backend_webgpu.ts | 7 +- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 2 + tfjs-backend-webgpu/src/clip_vec4_webgpu.ts | 2 +- tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts | 16 +- .../src/conv_backprop_mm_webgpu.ts | 10 +- .../src/conv_backprop_webgpu.ts | 2 + .../src/depthwise_conv2d_vec4_webgpu.ts | 2 +- .../src/matmul_packed_webgpu.ts | 7 +- .../src/matmul_splitK_webgpu.ts | 7 +- tfjs-backend-webgpu/src/scatter_webgpu.ts | 2 +- tfjs-backend-webgpu/src/webgpu_program.ts | 256 ++++++++---------- 12 files changed, 145 insertions(+), 184 deletions(-) diff --git a/tfjs-backend-webgpu/src/activation_util.ts b/tfjs-backend-webgpu/src/activation_util.ts index a11d5b93ef7..bda3917a29e 100644 --- a/tfjs-backend-webgpu/src/activation_util.ts +++ b/tfjs-backend-webgpu/src/activation_util.ts @@ -19,21 +19,7 @@ import {backend_util} from '@tensorflow/tfjs-core'; import {BinaryOpType, getBinaryOpString} from './binary_op_util'; import {getUnaryOpString, UnaryOpType} from './unary_op_util'; - -export const typeSnippet = (component: number) => { - switch (component) { - case 1: - return 'f32'; - case 2: - return 'vec2'; - case 3: - return 'vec3'; - case 4: - return 'vec4'; - default: - throw new Error(`${component}-component is not supported.`); - } -}; +import {typeSnippet} from './webgpu_program'; export function activationFnSnippet( activation: backend_util.Activation, hasPreluActivationWeights = false, diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index ab78ef7777a..ddeb5f61110 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -784,8 +784,11 @@ export class WebGPUBackend extends KernelBackend { programUniform.push({type: uniformsType, data: strides}); if (program.size) { const size = util.sizeFromShape(program.outputShape); - programUniform.push( - {type: uniformsType, data: [program.isVec4 ? size / 4 : size]}); + programUniform.push({ + type: uniformsType, + data: + [program.outputComponent ? size / program.outputComponent : size] + }); } } diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 891785d30d0..0fb1bf46b37 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -24,6 +24,7 @@ import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class BinaryOpProgram implements WebGPUProgram { dispatch: [number, number, number]; dispatchLayout: {x: number[]}; + outputComponent = 1; isVec4: boolean; op: BinaryOpType; outputShape: number[]; @@ -65,6 +66,7 @@ export class BinaryOpProgram implements WebGPUProgram { if (util.arraysEqual(aShape, bShape) && util.sizeFromShape(aShape) % 4 === 0) { this.isVec4 = true; + this.outputComponent = 4; this.type = 'vec4'; this.workPerThread = 4; } else { diff --git a/tfjs-backend-webgpu/src/clip_vec4_webgpu.ts b/tfjs-backend-webgpu/src/clip_vec4_webgpu.ts index a935a599f37..1cb5a2a0d2c 100644 --- a/tfjs-backend-webgpu/src/clip_vec4_webgpu.ts +++ b/tfjs-backend-webgpu/src/clip_vec4_webgpu.ts @@ -27,7 +27,7 @@ export class ClipVec4Program implements WebGPUProgram { dispatch: [number, number, number]; workPerThread = 4; workgroupSize: [number, number, number] = [64, 1, 1]; - isVec4 = true; + outputComponent = 4; size = true; constructor(outputShape: number[]) { diff --git a/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts index a1086bb35e3..77be2d8a9b5 100644 --- a/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts @@ -17,9 +17,9 @@ import {backend_util} from '@tensorflow/tfjs-core'; -import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {activationFnSnippet, biasActivationSnippet} from './activation_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -import {WebGPUProgram} from './webgpu_program'; +import {typeSnippet, WebGPUProgram} from './webgpu_program'; import {computeDispatch, computeWorkgroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util'; function conv2dCommonSnippet( @@ -159,7 +159,7 @@ export class Conv2DMMProgram implements WebGPUProgram { dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; variableNames = ['x', 'W']; - variableTypes: string[]; + variableComponents: number[]; uniforms = `filterDims : vec2, pads : vec2, strides : vec2, dilations : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32,`; workgroupSize: [number, number, number]; @@ -176,6 +176,7 @@ export class Conv2DMMProgram implements WebGPUProgram { tileInner: number; innerElementSize: number; isVec4?: boolean; + outputComponent = 1; private sequentialAccessByThreads: boolean; constructor( @@ -202,22 +203,23 @@ export class Conv2DMMProgram implements WebGPUProgram { this.elementsPerThread); if (this.isVec4) { + this.outputComponent = 4; if (this.isChannelsLast && convInfo.inChannels % 4 !== 0) { this.innerElementSize = 3; - this.variableTypes = ['f32', 'vec4']; + this.variableComponents = [1, 4]; } else { this.innerElementSize = 4; - this.variableTypes = ['vec4', 'vec4']; + this.variableComponents = [4, 4]; } if (addBias) { this.variableNames.push('bias'); - this.variableTypes.push('vec4'); + this.variableComponents.push(4); } if (hasPreluActivationWeights) { this.variableNames.push('preluActivationWeights'); - this.variableTypes.push('vec4'); + this.variableComponents.push(4); } } else { this.innerElementSize = this.elementsPerThread[0]; diff --git a/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts b/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts index ea95df6c052..458fb5d756e 100644 --- a/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts @@ -16,9 +16,9 @@ */ import {backend_util, util} from '@tensorflow/tfjs-core'; -import {typeSnippet} from './activation_util'; + import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -import {WebGPUProgram} from './webgpu_program'; +import {typeSnippet, WebGPUProgram} from './webgpu_program'; import {computeDispatch, computeWorkgroupSizeForConv2d, computeWorkPerThreadForConv2d} from './webgpu_util'; function conv2dTransposeCommonSnippet(innerElementSize = 4) { @@ -117,12 +117,13 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram { dispatchLayout: {x: number[], y: number[], z: number[]}; dispatch: [number, number, number]; variableNames = ['x', 'W']; - variableTypes: string[]; + variableComponents: number[]; uniforms = 'filterDims : vec2, pads : vec2, strides : vec2, outBackprop : vec4, dimAOuter : i32, dimBOuter : i32, dimInner : i32,'; workgroupSize: [number, number, number]; elementsPerThread: [number, number, number]; isVec4?: boolean; + outputComponent = 1; constructor(convInfo: backend_util.Conv2DInfo) { this.outputShape = convInfo.inShape; @@ -143,7 +144,8 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram { this.elementsPerThread); if (this.isVec4) { - this.variableTypes = ['vec4', 'f32']; + this.outputComponent = 4; + this.variableComponents = [4, 1]; } this.shaderKey = diff --git a/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts b/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts index e87ce221cab..958bd915c2a 100644 --- a/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts @@ -32,6 +32,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram { size = false; isVec4 = false; workPerThread = 1; + outputComponent = 1; constructor(convInfo: backend_util.Conv2DInfo) { this.outputShape = convInfo.inShape; @@ -41,6 +42,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram { if (this.isVec4) { // TODO: Expand to any value. this.workPerThread = 2; + this.outputComponent = 4; this.workgroupSize = [4, 4, 4]; this.dispatchLayout = {x: [3], y: [2], z: [0, 1]}; this.dispatch = computeDispatch( diff --git a/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts b/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts index bba9197e5b9..d5b310d3ecf 100644 --- a/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts +++ b/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts @@ -33,7 +33,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { addBias: boolean; activation: backend_util.Activation; hasPreluActivation: boolean; - isVec4 = true; + outputComponent = 4; constructor( convInfo: backend_util.Conv2DInfo, addBias = false, diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index 69a8a25b240..83cd21be764 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -16,8 +16,9 @@ */ import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core'; -import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; -import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; + +import {activationFnSnippet, biasActivationSnippet} from './activation_util'; +import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program'; import {computeDispatch, computeWorkgroupInfoForMatMul} from './webgpu_util'; export function matMulReadFnSource( @@ -509,6 +510,7 @@ export class MatMulPackedProgram implements WebGPUProgram { tileInner: number; isVectorA: boolean; isVec4: boolean; + outputComponent: number; private sequentialAccessByThreads: boolean; constructor( @@ -523,6 +525,7 @@ export class MatMulPackedProgram implements WebGPUProgram { this.isVec4 = ((dimInner % 4 === 0 && !transposeA) || (outputShape[1] % 4 === 0 && transposeA)) && outputShape[2] % 4 === 0 && !transposeB; + this.outputComponent = this.isVec4 ? 4 : 1; this.isVectorA = outputShape[1] === 1 && !transposeA; if (!this.isVec4 && this.isVectorA) { diff --git a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts index 1a41b469aa7..df0f013fff0 100644 --- a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts @@ -17,10 +17,10 @@ import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core'; -import {activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {activationFnSnippet, biasActivationSnippet} from './activation_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source, matMulReadFnSource} from './matmul_packed_webgpu'; import {atomicAddSnippet} from './shader_util'; -import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {getMainHeaderString as main, typeSnippet, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class MatMulSplitKProgram implements WebGPUProgram { @@ -36,6 +36,7 @@ export class MatMulSplitKProgram implements WebGPUProgram { transposeB: boolean; atomic = true; isVec4 = false; + outputComponent = 1; splitedDimInner = 128; constructor( @@ -50,7 +51,7 @@ export class MatMulSplitKProgram implements WebGPUProgram { !transposeA && dimInner % 4 === 0) && this.outputShape[2] % 4 === 0; this.elementsPerThread = [4, 4, this.splitedDimInner]; - + this.outputComponent = this.isVec4 ? 4 : 1; if (!this.isVec4) { if (this.outputShape[1] < 16) { this.elementsPerThread[1] = 1; diff --git a/tfjs-backend-webgpu/src/scatter_webgpu.ts b/tfjs-backend-webgpu/src/scatter_webgpu.ts index b878a5158d2..de4566380f4 100644 --- a/tfjs-backend-webgpu/src/scatter_webgpu.ts +++ b/tfjs-backend-webgpu/src/scatter_webgpu.ts @@ -107,7 +107,7 @@ export class ScatterProgram implements WebGPUProgram { flattenedIndex = flattenedIndex + indexInside * ${strideString}; } let updateValue = - ${mapToWgslTypes(this.type, false)}(${updatesSnippet}); + ${mapToWgslTypes(this.type)}(${updatesSnippet}); let flatIndex = getOutputIndexFromCoords(${outCoordsString}); ${ diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index 4d093fa56bd..5242fc01e07 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -28,7 +28,8 @@ export interface WebGPUProgram { // dispatch x,y,z dimensions. dispatchLayout: {x: number[], y?: number[], z?: number[]}; isFromPixels?: boolean; - isVec4?: boolean; + // By default, the output data component is 1. + outputComponent?: number; outputShape: number[]; // The unique key to distinguish different shader source code. shaderKey: string; @@ -36,10 +37,10 @@ export interface WebGPUProgram { size?: boolean; uniforms?: string; variableNames: string[]; - // Describe each variable's type and must have one-one mapping with - // variableNames. If not set, all variables type will be either f32 or - // vec4 based on isVec4 member. - variableTypes?: string[]; + // Describe each variable's component and must have one-one mapping with + // variableNames. If not set, all variables component will be same with output + // component member. + variableComponents?: number[]; // workgroupSize.x * workgroupSize.y * workgroupSize.z = the number of threads // in a thread group. Individual dimensions determines thread layout within // the group. @@ -67,6 +68,21 @@ export const compileProgram = return pipeline; }; +export const typeSnippet = (component: number, type = 'f') => { + switch (component) { + case 1: + return `${type}32`; + case 2: + return `vec2<${type}32>`; + case 3: + return `vec3<${type}32>`; + case 4: + return `vec4<${type}32>`; + default: + throw new Error(`${component}-component is not supported.`); + } +}; + export function getCoordsDataType(rank: number): string { if (rank <= 1) { return 'i32'; @@ -188,7 +204,7 @@ function makeShader( }; @group(0) @binding(0) var result: array<${ - mapToWgslTypes(outputData.dtype, program.isVec4)}>; + mapToWgslTypes(outputData.dtype, program.outputComponent)}>; @group(0) @binding(2) var uniforms: Uniform; `); const useGlobalIndex = isFlatDispatchLayout(program); @@ -234,15 +250,15 @@ function makeShader( } else { prefixSnippets.push(` @group(0) @binding(0) var result: array<${ - mapToWgslTypes(outputData.dtype, program.isVec4)}>; + mapToWgslTypes(outputData.dtype, program.outputComponent)}>; `); } program.variableNames.forEach((x, i) => { prefixSnippets.push(` @group(0) @binding(${1 + i}) var ${x}: array<${ - program.variableTypes ? - program.variableTypes[i] : - mapToWgslTypes(inputInfo[i].dtype, program.isVec4)}>; + program.variableComponents ? + mapToWgslTypes(inputInfo[i].dtype, program.variableComponents[i]) : + mapToWgslTypes(inputInfo[i].dtype, program.outputComponent)}>; `); }); @@ -262,8 +278,9 @@ function makeShader( getOutputIndexFromCoordsSnippet(outputData.shape.length) ]; if (!program.atomic) { - sources.push( - setOutputSnippet(outputData.shape, outputData.dtype, program.isVec4)); + sources.push(setOutputSnippet( + outputData.shape, outputData.dtype, + program.outputComponent ? program.outputComponent : 1)); } const inputSnippet = @@ -271,9 +288,9 @@ function makeShader( .map( (x, i) => getInputSnippet( x, outputData.shape, - program.variableTypes ? - (program.variableTypes[i] === 'vec4') : - program.isVec4, + program.variableComponents ? + program.variableComponents[i] : + program.outputComponent ? program.outputComponent : 1, program.dispatchLayout.x.length === outputData.shape.length)) .join('\n'); sources.push(inputSnippet); @@ -379,7 +396,8 @@ const isInfSnippet = ` type InputInfo = { dtype: DataType; shape: number[]; name: string; }; -export type WGSLDataType = 'f32'|'i32'|'vec4'|'vec4'|'vec4'; +export type WGSLDataType = 'f32'|'vec2'|'vec3'|'vec4'|'i32'| + 'vec2'|'vec3'|'vec4'; /** * Derives logical coordinates from a flat index. Performs integer division @@ -432,7 +450,7 @@ function getCoordsFromIndexSnippet(shape: number[]): string { } function getInputAtCoordsSnippet( - inputInfo: InputInfo, isVec4: boolean): string { + inputInfo: InputInfo, component: number): string { const texName = inputInfo.name; const rank = inputInfo.shape.length; const type = getCoordsDataType(rank); @@ -441,17 +459,9 @@ function getInputAtCoordsSnippet( const inputs = dims.map(d => `${d} : i32`).join(', '); if (rank < 1) { - if (isVec4) { - return ` - fn ${funcName}() -> vec4 { - return vec4(${texName}[0]); - } - `; - } - return ` - fn ${funcName}() ->f32 { - return f32(${texName}[0]); + fn ${funcName}() -> ${typeSnippet(component)} { + return ${typeSnippet(component)}(${texName}[0]); } `; } @@ -463,27 +473,17 @@ function getInputAtCoordsSnippet( rankStr = '1D'; } - if (isVec4) { - return ` - fn ${funcName}(${inputs}) -> vec4 { - return vec4(${texName}[getIndexFromCoords${rankStr}(${type}(${ - dims.join(',')}), - ${shapeStr}) / 4]); - } - `; - } - return ` - fn ${funcName}(${inputs}) -> f32 { - return f32(${texName}[getIndexFromCoords${rankStr}(${type}(${ - dims.join(',')}), - ${shapeStr})]); + fn ${funcName}(${inputs}) -> ${typeSnippet(component)} { + return ${typeSnippet(component)}(${texName}[getIndexFromCoords${ + rankStr}(${type}(${dims.join(',')}), + ${shapeStr})${component === 1 ? '' : ` / ${component}`}]); } `; } function getInputByOutputSnippet( - inputInfo: InputInfo, outShape: number[], isVec4: boolean, + inputInfo: InputInfo, outShape: number[], component: number, isFlatDispatchLayout: boolean): string { const texName = inputInfo.name; const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); @@ -498,29 +498,17 @@ function getInputByOutputSnippet( // directly use |gl_GlobalInvocationID.x| as the index and don't need coords // conversion between these two shapes. if (util.arraysEqual(inputInfo.shape, outShape) && isFlatDispatchLayout) { - if (isVec4) { - return ` - fn ${funcName}Index(globalIndex : i32) -> vec4 { - return vec4(${texName}[globalIndex]); - } - - fn ${funcName}Coords(coords : ${type}) -> vec4 { - return vec4(${texName}[${ - outRank > 1 ? 'getOutputIndexFromCoords(coords)' : 'coords'} / 4]); - } - `; - } else { - return ` - fn ${funcName}Index(globalIndex : i32) -> f32 { - return f32(${texName}[globalIndex]); + return ` + fn ${funcName}Index(globalIndex : i32) -> ${typeSnippet(component)} { + return ${typeSnippet(component)}(${texName}[globalIndex]); } - fn ${funcName}Coords(coords : ${type}) -> f32 { - return f32(${texName}[${ - outRank > 1 ? 'getOutputIndexFromCoords(coords)' : 'coords'}]); + fn ${funcName}Coords(coords : ${type}) -> ${typeSnippet(component)} { + return ${typeSnippet(component)}(${texName}[${ + outRank > 1 ? 'getOutputIndexFromCoords(coords)' : + 'coords'}${component === 1 ? '' : ` / ${component}`}]); } `; - } } const broadcastDims = @@ -530,23 +518,12 @@ function getInputByOutputSnippet( let coordsSnippet = ''; if (inRank === 0) { - if (isVec4) { - return ` - fn ${funcName}Index(globalIndex : i32) -> vec4 { - return get${texFuncSnippet}(); - } - - fn ${funcName}Coords(coords : ${type}) -> vec4 { - return get${texFuncSnippet}(); - } - `; - } return ` - fn ${funcName}Index(globalIndex : i32) -> f32{ + fn ${funcName}Index(globalIndex : i32) -> ${typeSnippet(component)}{ return get${texFuncSnippet}(); } - fn ${funcName}Coords(coords : ${type}) -> f32{ + fn ${funcName}Coords(coords : ${type}) -> ${typeSnippet(component)}{ return get${texFuncSnippet}(); } `; @@ -578,58 +555,43 @@ function getInputByOutputSnippet( const shapeStr = `uniforms.${texName.charAt(0).toLowerCase() + texName.slice(1)}Shape`; const rankStr = `${inRank}D`; - if (isVec4) { - return ` - fn ${funcName}Index(globalIndex : i32) -> vec4 { - var coords = getCoordsFromIndex(globalIndex); - ${coordsSnippet} - return ${texName}[getIndexFromCoords${rankStr}(${ - unpackedCoordsSnippet}, ${shapeStr}) / 4]; - } - - fn ${funcName}Coords(coordsIn : ${type}) -> vec4 { - var coords = coordsIn; - ${coordsSnippet} - return ${texName}[getIndexFromCoords${rankStr}(${ - unpackedCoordsSnippet}, ${shapeStr}) / 4]; - } - `; - } return ` - fn ${funcName}Index(globalIndex : i32) -> f32 { + fn ${funcName}Index(globalIndex : i32) -> ${typeSnippet(component)} { var coords = getCoordsFromIndex(globalIndex); ${coordsSnippet} - return f32(${texName}[getIndexFromCoords${rankStr}(${ - unpackedCoordsSnippet}, ${shapeStr})]); + return ${typeSnippet(component)}(${texName}[getIndexFromCoords${rankStr}(${ + unpackedCoordsSnippet}, ${shapeStr})${ + component === 1 ? '' : ` / ${component}`}]); } - fn ${funcName}Coords(coordsIn : ${type}) -> f32 { + fn ${funcName}Coords(coordsIn : ${type}) -> ${typeSnippet(component)} { var coords = coordsIn; ${coordsSnippet} - return f32(${texName}[getIndexFromCoords${rankStr}(${ - unpackedCoordsSnippet}, ${shapeStr})]); + return ${typeSnippet(component)}(${texName}[getIndexFromCoords${rankStr}(${ + unpackedCoordsSnippet}, ${shapeStr})${ + component === 1 ? '' : ` / ${component}`}]); } `; } function getInputSnippet( - inputInfo: InputInfo, outShape: number[], isVec4: boolean, + inputInfo: InputInfo, outShape: number[], component: number, isFlatDispatchLayout: boolean): string { - let res = getInputAtCoordsSnippet(inputInfo, isVec4); + let res = getInputAtCoordsSnippet(inputInfo, component); const inShape = inputInfo.shape; if (inShape.length <= outShape.length) { res += getInputByOutputSnippet( - inputInfo, outShape, isVec4, isFlatDispatchLayout); + inputInfo, outShape, component, isFlatDispatchLayout); } return res; } /** - * Generates getOutputCoords() function that computes output coordinates from - * dispatch geometry to reduce arithmetic. + * Generates getOutputCoords() function that computes output coordinates + * from dispatch geometry to reduce arithmetic. */ function getOutputCoordsSnippet( outShape: number[], @@ -768,72 +730,70 @@ function isFlatDispatch(program: WebGPUProgram): boolean { return program.dispatch[1] === 1 && program.dispatch[2] === 1; } -export function mapToWgslTypes(type: DataType, isVec4: boolean): WGSLDataType| - DataType { +export function mapToWgslTypes(type: DataType, component = 1): WGSLDataType { if (type === 'float32') { - return isVec4 ? 'vec4' : 'f32'; - } else if (type === 'int32') { - return isVec4 ? 'vec4' : 'i32'; - } else if (type === 'bool') { - // Type 'bool' cannot be used in storage class, - // https://www.w3.org/TR/WGSL/#host-shareable-types. - return isVec4 ? 'vec4' : 'i32'; + switch (component) { + case 1: + return 'f32'; + case 2: + return 'vec2'; + case 3: + return 'vec3'; + case 4: + return 'vec4'; + default: + throw new Error(`${component}-component is not supported.`); + } + } else if (type === 'int32' || type === 'bool') { + switch (component) { + case 1: + return 'i32'; + case 2: + return 'vec2'; + case 3: + return 'vec3'; + case 4: + return 'vec4'; + default: + throw new Error(`${component}-component is not supported.`); + } } - - return type; + throw new Error(`type ${type} is not supported.`); } function setOutputSnippet( - outShape: number[], outBufferType: DataType, isVec4: boolean): string { + outShape: number[], outBufferType: DataType, component: number): string { const outRank = outShape.length; - const wgslType = mapToWgslTypes(outBufferType, isVec4); - let snippet; - if (isVec4) { - snippet = `fn setOutputAtIndex(flatIndex : i32, value : vec4) { + const wgslType = mapToWgslTypes(outBufferType, component); + let snippet = + `fn setOutputAtIndex(flatIndex : i32, value : ${typeSnippet(component)}) { result[flatIndex] = ${wgslType}(value); } - fn setOutputAtIndexI32(flatIndex : i32, value : vec4) { - result[flatIndex] = ${wgslType}(value); - }`; - } else { - snippet = `fn setOutputAtIndex(flatIndex : i32, value : f32) { + + fn setOutputAtIndexI32(flatIndex : i32, value : ${ + typeSnippet(component, 'i')}) { result[flatIndex] = ${wgslType}(value); } - fn setOutputAtIndexI32(flatIndex : i32, value : i32) { - result[flatIndex] = ${wgslType}(value); - }`; - } + `; if (outRank >= 2) { const dims = ['d0', 'd1', 'd2', 'd3', 'd4', 'd5'].slice(0, outRank); const type = getCoordsDataType(outRank); - if (isVec4) { - snippet += ` - fn setOutputAtCoords(${ - dims.map(d => `${d} : i32`).join(', ')}, value : vec4) { - let flatIndex = getOutputIndexFromCoords(${type}(${dims.join(', ')})); - setOutputAtIndex(flatIndex / 4, value); - } - fn setOutputAtCoordsI32(${ - dims.map(d => `${d} : i32`).join(', ')}, value : vec4) { + snippet += ` + fn setOutputAtCoords(${dims.map(d => `${d} : i32`).join(', ')}, value : ${ + typeSnippet(component)}) { let flatIndex = getOutputIndexFromCoords(${type}(${dims.join(', ')})); - setOutputAtIndexI32(flatIndex / 4, value); - } - `; - } else { - snippet += ` - fn setOutputAtCoords(${ - dims.map(d => `${d} : i32`).join(', ')}, value : f32) { - let flatIndex = getOutputIndexFromCoords(${type}(${dims.join(', ')})); - setOutputAtIndex(flatIndex, value); + setOutputAtIndex(flatIndex${ + component === 1 ? '' : ` / ${component}`}, value); } fn setOutputAtCoordsI32(${ - dims.map(d => `${d} : i32`).join(', ')}, value : i32) { + dims.map(d => `${d} : i32`).join(', ')}, value : ${ + typeSnippet(component, 'i')}) { let flatIndex = getOutputIndexFromCoords(${type}(${dims.join(', ')})); - setOutputAtIndexI32(flatIndex, value); + setOutputAtIndexI32(flatIndex${ + component === 1 ? '' : ` / ${component}`}, value); } `; - } } return snippet; From 24342987f7cf03f63239a87be81bb1fd681d6e6c Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 1 Mar 2023 10:44:37 +0800 Subject: [PATCH 2/3] address comments --- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 2 +- tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts | 2 +- .../src/conv_backprop_mm_webgpu.ts | 2 +- .../src/conv_backprop_webgpu.ts | 2 +- .../src/matmul_splitK_webgpu.ts | 27 ++++--- tfjs-backend-webgpu/src/scatter_webgpu.ts | 5 +- tfjs-backend-webgpu/src/webgpu_program.ts | 71 +++++++------------ 7 files changed, 44 insertions(+), 67 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 0fb1bf46b37..3b17dd30850 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -24,7 +24,7 @@ import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class BinaryOpProgram implements WebGPUProgram { dispatch: [number, number, number]; dispatchLayout: {x: number[]}; - outputComponent = 1; + outputComponent: number; isVec4: boolean; op: BinaryOpType; outputShape: number[]; diff --git a/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts b/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts index 77be2d8a9b5..53238914f73 100644 --- a/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts @@ -176,7 +176,7 @@ export class Conv2DMMProgram implements WebGPUProgram { tileInner: number; innerElementSize: number; isVec4?: boolean; - outputComponent = 1; + outputComponent: number; private sequentialAccessByThreads: boolean; constructor( diff --git a/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts b/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts index 458fb5d756e..66bb02ae36a 100644 --- a/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts @@ -123,7 +123,7 @@ export class Conv2DDerInputMMProgram implements WebGPUProgram { workgroupSize: [number, number, number]; elementsPerThread: [number, number, number]; isVec4?: boolean; - outputComponent = 1; + outputComponent: number; constructor(convInfo: backend_util.Conv2DInfo) { this.outputShape = convInfo.inShape; diff --git a/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts b/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts index 958bd915c2a..4032432e359 100644 --- a/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +++ b/tfjs-backend-webgpu/src/conv_backprop_webgpu.ts @@ -32,7 +32,7 @@ export class Conv2DDerInputProgram implements WebGPUProgram { size = false; isVec4 = false; workPerThread = 1; - outputComponent = 1; + outputComponent: number; constructor(convInfo: backend_util.Conv2DInfo) { this.outputShape = convInfo.inShape; diff --git a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts index df0f013fff0..aefb3c9e7e0 100644 --- a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts @@ -35,8 +35,7 @@ export class MatMulSplitKProgram implements WebGPUProgram { transposeA: boolean; transposeB: boolean; atomic = true; - isVec4 = false; - outputComponent = 1; + outputComponent: number; splitedDimInner = 128; constructor( @@ -47,12 +46,12 @@ export class MatMulSplitKProgram implements WebGPUProgram { () => 'MatMulSplitKProgram only supports batch = 1.'); this.outputShape = outputShape; this.dispatchLayout = {x: [2], y: [1], z: [0, 3]}; - this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 || - !transposeA && dimInner % 4 === 0) && + const isVec4 = (transposeA && this.outputShape[1] % 4 === 0 || + !transposeA && dimInner % 4 === 0) && this.outputShape[2] % 4 === 0; this.elementsPerThread = [4, 4, this.splitedDimInner]; - this.outputComponent = this.isVec4 ? 4 : 1; - if (!this.isVec4) { + this.outputComponent = isVec4 ? 4 : 1; + if (!isVec4) { if (this.outputShape[1] < 16) { this.elementsPerThread[1] = 1; } @@ -72,11 +71,11 @@ export class MatMulSplitKProgram implements WebGPUProgram { this.transposeA = transposeA; this.transposeB = transposeB; this.shaderKey = `matMulSplitK_${transposeA}_${transposeB}_${ - this.elementsPerThread}_${this.isVec4}`; + this.elementsPerThread}_${this.outputComponent}`; } getUserCode(): string { - const component = this.isVec4 ? 4 : 1; + const component = this.outputComponent; const userCode = ` ${ matMulReadFnSource( @@ -98,12 +97,12 @@ export class MatMulSplitKProgram implements WebGPUProgram { } } ${ - this.isVec4 ? makeMatMulPackedVec4Source( - this.elementsPerThread, this.workgroupSize, - this.transposeA, 32, true, this.splitedDimInner) : - makeMatMulPackedSource( - this.elementsPerThread, this.workgroupSize, - this.transposeA, 32, true, this.splitedDimInner)} + component === 4 ? makeMatMulPackedVec4Source( + this.elementsPerThread, this.workgroupSize, + this.transposeA, 32, true, this.splitedDimInner) : + makeMatMulPackedSource( + this.elementsPerThread, this.workgroupSize, + this.transposeA, 32, true, this.splitedDimInner)} `; return userCode; } diff --git a/tfjs-backend-webgpu/src/scatter_webgpu.ts b/tfjs-backend-webgpu/src/scatter_webgpu.ts index de4566380f4..e866424a33a 100644 --- a/tfjs-backend-webgpu/src/scatter_webgpu.ts +++ b/tfjs-backend-webgpu/src/scatter_webgpu.ts @@ -16,8 +16,9 @@ */ import {DataType} from '@tensorflow/tfjs-core'; + import {atomicAddSnippet} from './shader_util'; -import {getCoordsDataType, getMainHeaderString as main, mapToWgslTypes, WebGPUProgram} from './webgpu_program'; +import {dataTypeToGPUType, getCoordsDataType, getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class ScatterProgram implements WebGPUProgram { @@ -107,7 +108,7 @@ export class ScatterProgram implements WebGPUProgram { flattenedIndex = flattenedIndex + indexInside * ${strideString}; } let updateValue = - ${mapToWgslTypes(this.type)}(${updatesSnippet}); + ${dataTypeToGPUType(this.type)}(${updatesSnippet}); let flatIndex = getOutputIndexFromCoords(${outCoordsString}); ${ diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index 5242fc01e07..235f6a6ba1a 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -68,18 +68,18 @@ export const compileProgram = return pipeline; }; -export const typeSnippet = (component: number, type = 'f') => { +export const typeSnippet = (component: number, type = 'f32') => { switch (component) { case 1: - return `${type}32`; + return `${type}`; case 2: - return `vec2<${type}32>`; + return `vec2<${type}>`; case 3: - return `vec3<${type}32>`; + return `vec3<${type}>`; case 4: - return `vec4<${type}32>`; + return `vec4<${type}>`; default: - throw new Error(`${component}-component is not supported.`); + throw new Error(`${component}-component ${type} is not supported.`); } }; @@ -174,6 +174,8 @@ function makeShader( const prefixSnippets: string[] = []; const flatWorkgroupSize = program.workgroupSize[0] * program.workgroupSize[1] * program.workgroupSize[2]; + program.outputComponent = + program.outputComponent ? program.outputComponent : 1 prefixSnippets.push(` var localId: vec3; @@ -204,7 +206,7 @@ function makeShader( }; @group(0) @binding(0) var result: array<${ - mapToWgslTypes(outputData.dtype, program.outputComponent)}>; + dataTypeToGPUType(outputData.dtype, program.outputComponent)}>; @group(0) @binding(2) var uniforms: Uniform; `); const useGlobalIndex = isFlatDispatchLayout(program); @@ -250,15 +252,16 @@ function makeShader( } else { prefixSnippets.push(` @group(0) @binding(0) var result: array<${ - mapToWgslTypes(outputData.dtype, program.outputComponent)}>; + dataTypeToGPUType(outputData.dtype, program.outputComponent)}>; `); } program.variableNames.forEach((x, i) => { prefixSnippets.push(` @group(0) @binding(${1 + i}) var ${x}: array<${ program.variableComponents ? - mapToWgslTypes(inputInfo[i].dtype, program.variableComponents[i]) : - mapToWgslTypes(inputInfo[i].dtype, program.outputComponent)}>; + dataTypeToGPUType( + inputInfo[i].dtype, program.variableComponents[i]) : + dataTypeToGPUType(inputInfo[i].dtype, program.outputComponent)}>; `); }); @@ -279,8 +282,7 @@ function makeShader( ]; if (!program.atomic) { sources.push(setOutputSnippet( - outputData.shape, outputData.dtype, - program.outputComponent ? program.outputComponent : 1)); + outputData.shape, outputData.dtype, program.outputComponent)); } const inputSnippet = @@ -288,9 +290,8 @@ function makeShader( .map( (x, i) => getInputSnippet( x, outputData.shape, - program.variableComponents ? - program.variableComponents[i] : - program.outputComponent ? program.outputComponent : 1, + program.variableComponents ? program.variableComponents[i] : + program.outputComponent, program.dispatchLayout.x.length === outputData.shape.length)) .join('\n'); sources.push(inputSnippet); @@ -396,8 +397,6 @@ const isInfSnippet = ` type InputInfo = { dtype: DataType; shape: number[]; name: string; }; -export type WGSLDataType = 'f32'|'vec2'|'vec3'|'vec4'|'i32'| - 'vec2'|'vec3'|'vec4'; /** * Derives logical coordinates from a flat index. Performs integer division @@ -730,33 +729,11 @@ function isFlatDispatch(program: WebGPUProgram): boolean { return program.dispatch[1] === 1 && program.dispatch[2] === 1; } -export function mapToWgslTypes(type: DataType, component = 1): WGSLDataType { +export function dataTypeToGPUType(type: DataType, component = 1) { if (type === 'float32') { - switch (component) { - case 1: - return 'f32'; - case 2: - return 'vec2'; - case 3: - return 'vec3'; - case 4: - return 'vec4'; - default: - throw new Error(`${component}-component is not supported.`); - } + return typeSnippet(component, 'f32'); } else if (type === 'int32' || type === 'bool') { - switch (component) { - case 1: - return 'i32'; - case 2: - return 'vec2'; - case 3: - return 'vec3'; - case 4: - return 'vec4'; - default: - throw new Error(`${component}-component is not supported.`); - } + return typeSnippet(component, 'i32'); } throw new Error(`type ${type} is not supported.`); } @@ -764,15 +741,15 @@ export function mapToWgslTypes(type: DataType, component = 1): WGSLDataType { function setOutputSnippet( outShape: number[], outBufferType: DataType, component: number): string { const outRank = outShape.length; - const wgslType = mapToWgslTypes(outBufferType, component); + const gpuType = dataTypeToGPUType(outBufferType, component); let snippet = `fn setOutputAtIndex(flatIndex : i32, value : ${typeSnippet(component)}) { - result[flatIndex] = ${wgslType}(value); + result[flatIndex] = ${gpuType}(value); } fn setOutputAtIndexI32(flatIndex : i32, value : ${ - typeSnippet(component, 'i')}) { - result[flatIndex] = ${wgslType}(value); + typeSnippet(component, 'i32')}) { + result[flatIndex] = ${gpuType}(value); } `; if (outRank >= 2) { @@ -788,7 +765,7 @@ function setOutputSnippet( } fn setOutputAtCoordsI32(${ dims.map(d => `${d} : i32`).join(', ')}, value : ${ - typeSnippet(component, 'i')}) { + typeSnippet(component, 'i32')}) { let flatIndex = getOutputIndexFromCoords(${type}(${dims.join(', ')})); setOutputAtIndexI32(flatIndex${ component === 1 ? '' : ` / ${component}`}, value); From aa8a033e8a28417478da5e18bd4781044fa65d83 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 1 Mar 2023 11:10:43 +0800 Subject: [PATCH 3/3] Fix build error --- tfjs-backend-webgpu/src/webgpu_program.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index 235f6a6ba1a..70f149f3191 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -175,7 +175,7 @@ function makeShader( const flatWorkgroupSize = program.workgroupSize[0] * program.workgroupSize[1] * program.workgroupSize[2]; program.outputComponent = - program.outputComponent ? program.outputComponent : 1 + program.outputComponent ? program.outputComponent : 1; prefixSnippets.push(` var localId: vec3;