Skip to content

Commit

Permalink
[WebGPU] support AvgPool3DGrad kernel (#7440)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao authored Mar 2, 2023
1 parent 71faceb commit 9ea4a8e
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export class AvgPool2DBackpropProgram implements WebGPUProgram {
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);

this.shaderKey = `avg_pool2d_backprop`;
this.shaderKey = `avgPool2DBackprop`;
}

getUserCode(): string {
Expand Down Expand Up @@ -85,3 +85,79 @@ export class AvgPool2DBackpropProgram implements WebGPUProgram {
return userCode;
}
}

export class AvgPool3DBackpropProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
variableNames = ['dy'];
uniforms = `strides : vec3<i32>, pads : vec3<i32>, filterDims : vec3<i32>,
outDepth : i32, outHeight : i32, outWidth : i32, avgMultiplier : f32,`;
workgroupSize: [number, number, number] = [64, 1, 1];
size = true;

constructor(convInfo: backend_util.Conv3DInfo) {
this.outputShape = convInfo.inShape;

this.dispatchLayout = flatDispatchLayout(this.outputShape);

this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);

this.shaderKey = `avgPool3DBackprop`;
}

getUserCode(): string {
const userCode = `
${main('index')} {
if (index < uniforms.size) {
let coords = getCoordsFromIndex(index);
let batch = coords.x;
let ch = coords.u;
let dyCorner = vec3<i32>(coords.y, coords.z, coords.w) - uniforms.pads;
let dyDCorner = dyCorner.x;
let dyRCorner = dyCorner.y;
let dyCCorner = dyCorner.z;
// Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get
// dx(xD, xR, xC, ch).
// ? = to be determined. : = across all values in that axis.
var dotProd = 0.0;
for (var wD = 0; wD < uniforms.filterDims[0]; wD++) {
let dyD = f32(dyDCorner + wD) / f32(uniforms.strides[0]);
if (dyD < 0.0 || dyD >= f32(uniforms.outDepth) || fract(dyD) > 0.0) {
continue;
}
let idyD = i32(dyD);
for (var wR = 0; wR < uniforms.filterDims[1]; wR++) {
let dyR = f32(dyRCorner + wR) / f32(uniforms.strides[1]);
if (dyR < 0.0 || dyR >= f32(uniforms.outHeight) || fract(dyR) > 0.0) {
continue;
}
let idyR = i32(dyR);
for (var wC = 0; wC < uniforms.filterDims[2]; wC++) {
let dyC = f32(dyCCorner + wC) / f32(uniforms.strides[2]);
if (dyC < 0.0 || dyC >= f32(uniforms.outWidth) || fract(dyC) > 0.0) {
continue;
}
let idyC = i32(dyC);
let dyValue = getDy(batch, idyD, idyR, idyC, ch);
dotProd += dyValue * uniforms.avgMultiplier;
}
}
}
setOutputAtIndex(index, dotProd);
}
}
`;
return userCode;
}
}
71 changes: 71 additions & 0 deletions tfjs-backend-webgpu/src/kernels/AvgPool3DGrad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {AvgPool3DGrad, AvgPool3DGradAttrs, AvgPool3DGradInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';

import {AvgPool3DBackpropProgram} from '../avg_pool_backprop_webgpu';
import {WebGPUBackend} from '../backend_webgpu';

export function avgPool3DGrad(args: {
inputs: AvgPool3DGradInputs,
backend: WebGPUBackend,
attrs: AvgPool3DGradAttrs
}): TensorInfo {
const {inputs, backend, attrs} = args;
const {dy, input} = inputs;
const x = input;
const {filterSize, strides, pad, dimRoundingMode} = attrs;

const convInfo = backend_util.computePool3DInfo(
x.shape as [number, number, number, number, number], filterSize, strides,
1 /* dilations */, pad, dimRoundingMode);
const program = new AvgPool3DBackpropProgram(convInfo);
const avgMultiplier =
1 / (convInfo.filterDepth * convInfo.filterHeight * convInfo.filterWidth);
const uniformData = [
{
type: 'int32',
data: [convInfo.strideDepth, convInfo.strideHeight, convInfo.strideWidth]
},
{
type: 'int32',
data: [
convInfo.effectiveFilterDepth - 1 - convInfo.padInfo.front,
convInfo.effectiveFilterHeight - 1 - convInfo.padInfo.top,
convInfo.effectiveFilterWidth - 1 - convInfo.padInfo.left
]
},
{
type: 'int32',
data: [
convInfo.effectiveFilterDepth, convInfo.effectiveFilterHeight,
convInfo.effectiveFilterWidth
]
},
{type: 'int32', data: [convInfo.outDepth]},
{type: 'int32', data: [convInfo.outHeight]},
{type: 'int32', data: [convInfo.outWidth]},
{type: 'float32', data: [avgMultiplier]}
];
return backend.runWebGPUProgram(program, [dy], x.dtype, uniformData);
}

export const avgPool3DGradConfig: KernelConfig = {
kernelName: AvgPool3DGrad,
backendName: 'webgpu',
kernelFunc: avgPool3DGrad as unknown as KernelFunc
};
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/kernels/AvgPoolGrad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {AvgPoolGrad, AvgPoolGradAttrs, AvgPoolGradInputs, backend_util, KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core';

import {AvgPool2DBackpropProgram} from '../avg_pool2d_backprop_webgpu';
import {AvgPool2DBackpropProgram} from '../avg_pool_backprop_webgpu';
import {WebGPUBackend} from '../backend_webgpu';
import {assertNotComplex} from '../webgpu_util';

Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {atan2Config} from './kernels/Atan2';
import {atanhConfig} from './kernels/Atanh';
import {avgPoolConfig} from './kernels/AvgPool';
import {avgPool3DConfig} from './kernels/AvgPool3D';
import {avgPool3DGradConfig} from './kernels/AvgPool3DGrad';
import {avgPoolGradConfig} from './kernels/AvgPoolGrad';
import {batchMatMulConfig} from './kernels/BatchMatMul';
import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND';
Expand Down Expand Up @@ -189,6 +190,7 @@ const kernelConfigs: KernelConfig[] = [
atanhConfig,
avgPoolConfig,
avgPool3DConfig,
avgPool3DGradConfig,
avgPoolGradConfig,
batchMatMulConfig,
batchToSpaceNDConfig,
Expand Down
1 change: 0 additions & 1 deletion tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ const TEST_FILTERS: TestFilter[] = [
{
include: ' webgpu ',
excludes: [
'avgPool3dBackprop ',
'raggedGather ',
'raggedRange ',
'raggedTensorToTensor ',
Expand Down

0 comments on commit 9ea4a8e

Please sign in to comment.