Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] Support ResizeBilinearGrad kernel #7385

Merged
merged 3 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions tfjs-backend-webgpu/src/kernels/ResizeBilinearGrad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, ResizeBilinearGrad, ResizeBilinearGradAttrs, ResizeBilinearGradInputs, TensorInfo} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {ResizeBilinearBackpropProgram} from '../resize_bilinear_backprop_webgpu';

export function resizeBilinearGrad(args: {
inputs: ResizeBilinearGradInputs,
backend: WebGPUBackend,
attrs: ResizeBilinearGradAttrs
}): TensorInfo {
const {inputs, backend, attrs} = args;
const {images, dy} = inputs;
const {alignCorners} = attrs;

const [, xHeight, xWidth, ] =
images.shape as [number, number, number, number];
const [, yHeight, yWidth] = dy.shape as [number, number, number, number];

const effectiveXSize: [number, number] = [
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
];

const effectiveYSize: [number, number] = [
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
];

const heightScale = effectiveXSize[0] / effectiveYSize[0];
const widthScale = effectiveXSize[1] / effectiveYSize[1];

const invHeightScale = 1 / heightScale;
const invWidthScale = 1 / widthScale;

// This defines the size of the window of values around a particular
// index in dy that we want to search for contributions to dx.
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;

const program = new ResizeBilinearBackpropProgram(
images.shape as [number, number, number, number], alignCorners);
const uniformData = [
{type: 'int32', data: effectiveXSize},
{type: 'int32', data: effectiveYSize},
{type: 'float32', data: [heightScale]},
{type: 'float32', data: [widthScale]},
{type: 'float32', data: [invHeightScale]},
{type: 'float32', data: [invWidthScale]},
{type: 'int32', data: [winHeight]}, {type: 'int32', data: [winWidth]}
];
return backend.runWebGPUProgram(program, [dy], dy.dtype, uniformData);
}

export const resizeBilinearGradConfig: KernelConfig = {
kernelName: ResizeBilinearGrad,
backendName: 'webgpu',
kernelFunc: resizeBilinearGrad as unknown as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ import {reluConfig} from './kernels/Relu';
import {relu6Config} from './kernels/Relu6';
import {reshapeConfig} from './kernels/Reshape';
import {resizeBilinearConfig} from './kernels/ResizeBilinear';
import {resizeBilinearGradConfig} from './kernels/ResizeBilinearGrad';
import {resizeNearestNeighborConfig} from './kernels/ResizeNearestNeighbor';
import {resizeNearestNeighborGradConfig} from './kernels/ResizeNearestNeighborGrad';
import {reverseConfig} from './kernels/Reverse';
Expand Down Expand Up @@ -277,6 +278,7 @@ const kernelConfigs: KernelConfig[] = [
relu6Config,
reshapeConfig,
resizeBilinearConfig,
resizeBilinearGradConfig,
resizeNearestNeighborConfig,
resizeNearestNeighborGradConfig,
reverseConfig,
Expand Down
124 changes: 124 additions & 0 deletions tfjs-backend-webgpu/src/resize_bilinear_backprop_webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

export class ResizeBilinearBackpropProgram implements WebGPUProgram {
outputShape: number[];
shaderKey: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
variableNames = ['dy'];
uniforms =
`effectiveXSize : vec2<i32>, effectiveYSize : vec2<i32>, heightScale : f32, widthScale : f32,
invHeightScale : f32, invWidthScale : f32, winHeight : i32, winWidth : i32,`;
workgroupSize: [number, number, number] = [64, 1, 1];
alignCorners: boolean;
size = true;

constructor(
inputShape: [number, number, number, number], alignCorners: boolean) {
this.outputShape = inputShape;

this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);

this.alignCorners = alignCorners;
this.shaderKey = `resizeBilinearBackprop_${alignCorners}`;
}

getUserCode(): string {
const userCode = `
${main('index')} {
if (index < uniforms.size) {
let coords = getOutputCoords();
let b = coords[0];
let d = coords[3];
let r = coords[1];
let c = coords[2];

var accumulator = 0.0;

// Compute bounds for where in dy we will look
let startRLerp = floor(f32(r) * uniforms.invHeightScale);
let startDyR = i32(startRLerp - f32(uniforms.winHeight / 2));

let startCLerp = floor(f32(c) * uniforms.invWidthScale);
let startDyC = i32(startCLerp - f32(uniforms.winWidth / 2));

// Loop over dy
for (var dyROffset = 0; dyROffset < uniforms.winHeight; dyROffset++) {
let dyR = startDyR + dyROffset;

// Guard against the window exceeding the bounds of dy
if (dyR < 0 || dyR >= uniforms.dyShape[1]) {
continue;
}

for (var dyCOffset = 0; dyCOffset < uniforms.winWidth; dyCOffset++) {
let dyC = startDyC + dyCOffset;

// Guard against the window exceeding the bounds of dy
if (dyC < 0 || dyC >= uniforms.dyShape[2]) {
continue;
}

let dxR = f32(dyR) * uniforms.heightScale;
let topDxRIndex = i32(floor(dxR));
let bottomDxRIndex = i32(min(ceil(dxR), f32(uniforms.outShape[1] - 1)));
let dxRLerp = dxR - f32(topDxRIndex);
let inverseDxRLerp = 1.0 - dxRLerp;

let dxC = f32(dyC) * uniforms.widthScale;
let leftDxCIndex = i32(floor(dxC));
let rightDxCIndex = i32(min(ceil(dxC), f32(uniforms.outShape[2] - 1)));
let dxCLerp = dxC - f32(leftDxCIndex);
let inverseDxCLerp = 1.0 - dxCLerp;

if (r == topDxRIndex && c == leftDxCIndex) {
// topLeft
accumulator +=
getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
}

if (r == topDxRIndex && c == rightDxCIndex) {
// topRight
accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
}

if (r == bottomDxRIndex && c == leftDxCIndex) {
// bottomLeft
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
}

if (r == bottomDxRIndex && c == rightDxCIndex) {
// bottomRight
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
}
}
}
// End loop over dy

setOutputAtIndex(index, accumulator);
}
}
`;
return userCode;
}
}
6 changes: 0 additions & 6 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ const TEST_FILTERS: TestFilter[] = [
'gradients', // Not yet implemented
]
},
{
startsWith: 'resizeBilinear ',
excludes: [
'gradients', // Not yet implemented
]
},

// exclude unsupported kernels and to be fixed cases
{
Expand Down