Skip to content

Commit

Permalink
[WebGPU] Support ResizeBilinearGrad kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed Feb 20, 2023
1 parent 7921dd5 commit ff60fc2
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 6 deletions.
74 changes: 74 additions & 0 deletions tfjs-backend-webgpu/src/kernels/ResizeBilinearGrad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, ResizeBilinearGrad, ResizeBilinearGradAttrs, ResizeBilinearGradInputs, TensorInfo} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {ResizeBilinearBackpropProgram} from '../resize_bilinear_backprop_webgpu';

export function resizeBilinearGrad(args: {
inputs: ResizeBilinearGradInputs,
backend: WebGPUBackend,
attrs: ResizeBilinearGradAttrs
}): TensorInfo {
const {inputs, backend, attrs} = args;
const {images, dy} = inputs;
const {alignCorners} = attrs;

const [, xHeight, xWidth, ] = images.shape as [number, number, number, number];
const [, yHeight, yWidth] = dy.shape as [number, number, number, number];

const effectiveXSize: [number, number] = [
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
];

const effectiveYSize: [number, number] = [
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
];

const heightScale = effectiveXSize[0] / effectiveYSize[0];
const widthScale = effectiveXSize[1] / effectiveYSize[1];

const invHeightScale = 1 / heightScale;
const invWidthScale = 1 / widthScale;

// This defines the size of the window of values around a particular
// index in dy that we want to search for contributions to dx.
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;

const program = new ResizeBilinearBackpropProgram(
images.shape as [number, number, number, number], alignCorners);
const uniformData = [
{type: 'int32', data: effectiveXSize},
{type: 'int32', data: effectiveYSize},
{type: 'float32', data: [heightScale]},
{type: 'float32', data: [widthScale]},
{type: 'float32', data: [invHeightScale]},
{type: 'float32', data: [invWidthScale]},
{type: 'int32', data: [winHeight]}, {type: 'int32', data: [winWidth]}
];
return backend.runWebGPUProgram(program, [dy], dy.dtype, uniformData);
}

export const resizeBilinearGradConfig: KernelConfig = {
kernelName: ResizeBilinearGrad,
backendName: 'webgpu',
kernelFunc: resizeBilinearGrad as unknown as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ import {reluConfig} from './kernels/Relu';
import {relu6Config} from './kernels/Relu6';
import {reshapeConfig} from './kernels/Reshape';
import {resizeBilinearConfig} from './kernels/ResizeBilinear';
import {resizeBilinearGradConfig} from './kernels/ResizeBilinearGrad';
import {resizeNearestNeighborConfig} from './kernels/ResizeNearestNeighbor';
import {resizeNearestNeighborGradConfig} from './kernels/ResizeNearestNeighborGrad';
import {reverseConfig} from './kernels/Reverse';
Expand Down Expand Up @@ -277,6 +278,7 @@ const kernelConfigs: KernelConfig[] = [
relu6Config,
reshapeConfig,
resizeBilinearConfig,
resizeBilinearGradConfig,
resizeNearestNeighborConfig,
resizeNearestNeighborGradConfig,
reverseConfig,
Expand Down
124 changes: 124 additions & 0 deletions tfjs-backend-webgpu/src/resize_bilinear_backprop_webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class ResizeBilinearBackpropProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
variableNames = ['dy'];
uniforms =
`effectiveXSize : vec2<i32>, effectiveYSize : vec2<i32>, heightScale : f32, widthScale : f32,
invHeightScale : f32, invWidthScale : f32, winHeight : i32, winWidth : i32,`;
workgroupSize: [number, number, number] = [64, 1, 1];
alignCorners: boolean;
size = true;

constructor(
inputShape: [number, number, number, number], alignCorners: boolean) {
this.outputShape = inputShape;

this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);

this.alignCorners = alignCorners;
this.shaderKey = `resizeBilinearBackprop_${alignCorners}`;
}

getUserCode(): string {
const userCode = `
${main('index')} {
if (index < uniforms.size) {
let coords = getOutputCoords();
let b = coords[0];
let d = coords[3];
let r = coords[1];
let c = coords[2];
var accumulator = 0.0;
// Compute bounds for where in dy we will look
let startRLerp = floor(f32(r) * uniforms.invHeightScale);
let startDyR = i32(startRLerp - f32(uniforms.winHeight / 2));
let startCLerp = floor(f32(c) * uniforms.invWidthScale);
let startDyC = i32(startCLerp - f32(uniforms.winWidth / 2));
// Loop over dy
for (var dyROffset = 0; dyROffset < uniforms.winHeight; dyROffset++) {
let dyR = startDyR + dyROffset;
// Guard against the window exceeding the bounds of dy
if (dyR < 0 || dyR >= uniforms.dyShape[1]) {
continue;
}
for (var dyCOffset = 0; dyCOffset < uniforms.winWidth; dyCOffset++) {
let dyC = startDyC + dyCOffset;
// Guard against the window exceeding the bounds of dy
if (dyC < 0 || dyC >= uniforms.dyShape[2]) {
continue;
}
let dxR = f32(dyR) * uniforms.heightScale;
let topDxRIndex = i32(floor(dxR));
let bottomDxRIndex = i32(min(ceil(dxR), f32(uniforms.outShape[1] - 1)));
let dxRLerp = dxR - f32(topDxRIndex);
let inverseDxRLerp = 1.0 - dxRLerp;
let dxC = f32(dyC) * uniforms.widthScale;
let leftDxCIndex = i32(floor(dxC));
let rightDxCIndex = i32(min(ceil(dxC), f32(uniforms.outShape[2] - 1)));
let dxCLerp = dxC - f32(leftDxCIndex);
let inverseDxCLerp = 1.0 - dxCLerp;
if (r == topDxRIndex && c == leftDxCIndex) {
// topLeft
accumulator +=
getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
}
if (r == topDxRIndex && c == rightDxCIndex) {
// topRight
accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
}
if (r == bottomDxRIndex && c == leftDxCIndex) {
// bottomLeft
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
}
if (r == bottomDxRIndex && c == rightDxCIndex) {
// bottomRight
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
}
}
}
// End loop over dy
setOutputAtIndex(index, accumulator);
}
}
`;
return userCode;
}
}
6 changes: 0 additions & 6 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ const TEST_FILTERS: TestFilter[] = [
'gradients', // Not yet implemented
]
},
{
startsWith: 'resizeBilinear ',
excludes: [
'gradients', // Not yet implemented
]
},

// exclude unsupported kernels and to be fixed cases
{
Expand Down

0 comments on commit ff60fc2

Please sign in to comment.