From 5b519bf3d008ea24e44497746996d7e0e2e98007 Mon Sep 17 00:00:00 2001 From: xhcao Date: Wed, 22 Feb 2023 09:07:15 +0800 Subject: [PATCH] [WebGPU] Support ResizeBilinearGrad kernel (#7385) --- .../src/kernels/ResizeBilinearGrad.ts | 75 +++++++++++ .../src/register_all_kernels.ts | 2 + .../src/resize_bilinear_backprop_webgpu.ts | 124 ++++++++++++++++++ tfjs-backend-webgpu/src/setup_test.ts | 6 - 4 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 tfjs-backend-webgpu/src/kernels/ResizeBilinearGrad.ts create mode 100644 tfjs-backend-webgpu/src/resize_bilinear_backprop_webgpu.ts diff --git a/tfjs-backend-webgpu/src/kernels/ResizeBilinearGrad.ts b/tfjs-backend-webgpu/src/kernels/ResizeBilinearGrad.ts new file mode 100644 index 00000000000..7af86c719bb --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/ResizeBilinearGrad.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, ResizeBilinearGrad, ResizeBilinearGradAttrs, ResizeBilinearGradInputs, TensorInfo} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {ResizeBilinearBackpropProgram} from '../resize_bilinear_backprop_webgpu'; + +export function resizeBilinearGrad(args: { + inputs: ResizeBilinearGradInputs, + backend: WebGPUBackend, + attrs: ResizeBilinearGradAttrs +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {images, dy} = inputs; + const {alignCorners} = attrs; + + const [, xHeight, xWidth, ] = + images.shape as [number, number, number, number]; + const [, yHeight, yWidth] = dy.shape as [number, number, number, number]; + + const effectiveXSize: [number, number] = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + + const effectiveYSize: [number, number] = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth + ]; + + const heightScale = effectiveXSize[0] / effectiveYSize[0]; + const widthScale = effectiveXSize[1] / effectiveYSize[1]; + + const invHeightScale = 1 / heightScale; + const invWidthScale = 1 / widthScale; + + // This defines the size of the window of values around a particular + // index in dy that we want to search for contributions to dx. + const winHeight = (Math.ceil(invHeightScale) * 2) + 2; + const winWidth = (Math.ceil(invWidthScale) * 2) + 2; + + const program = new ResizeBilinearBackpropProgram( + images.shape as [number, number, number, number], alignCorners); + const uniformData = [ + {type: 'int32', data: effectiveXSize}, + {type: 'int32', data: effectiveYSize}, + {type: 'float32', data: [heightScale]}, + {type: 'float32', data: [widthScale]}, + {type: 'float32', data: [invHeightScale]}, + {type: 'float32', data: [invWidthScale]}, + {type: 'int32', data: [winHeight]}, {type: 'int32', data: [winWidth]} + ]; + return backend.runWebGPUProgram(program, [dy], dy.dtype, uniformData); +} + +export const resizeBilinearGradConfig: KernelConfig = { + kernelName: ResizeBilinearGrad, + backendName: 'webgpu', + kernelFunc: resizeBilinearGrad as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 79270b00b43..29c457128ac 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -127,6 +127,7 @@ import {reluConfig} from './kernels/Relu'; import {relu6Config} from './kernels/Relu6'; import {reshapeConfig} from './kernels/Reshape'; import {resizeBilinearConfig} from './kernels/ResizeBilinear'; +import {resizeBilinearGradConfig} from './kernels/ResizeBilinearGrad'; import {resizeNearestNeighborConfig} from './kernels/ResizeNearestNeighbor'; import {resizeNearestNeighborGradConfig} from './kernels/ResizeNearestNeighborGrad'; import {reverseConfig} from './kernels/Reverse'; @@ -277,6 +278,7 @@ const kernelConfigs: KernelConfig[] = [ relu6Config, reshapeConfig, resizeBilinearConfig, + resizeBilinearGradConfig, resizeNearestNeighborConfig, resizeNearestNeighborGradConfig, reverseConfig, diff --git a/tfjs-backend-webgpu/src/resize_bilinear_backprop_webgpu.ts b/tfjs-backend-webgpu/src/resize_bilinear_backprop_webgpu.ts new file mode 100644 index 00000000000..d24a46ed3ac --- /dev/null +++ b/tfjs-backend-webgpu/src/resize_bilinear_backprop_webgpu.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {computeDispatch, flatDispatchLayout} from './webgpu_util'; + +export class ResizeBilinearBackpropProgram implements WebGPUProgram { + outputShape: number[]; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + variableNames = ['dy']; + uniforms = + `effectiveXSize : vec2, effectiveYSize : vec2, heightScale : f32, widthScale : f32, + invHeightScale : f32, invWidthScale : f32, winHeight : i32, winWidth : i32,`; + workgroupSize: [number, number, number] = [64, 1, 1]; + alignCorners: boolean; + size = true; + + constructor( + inputShape: [number, number, number, number], alignCorners: boolean) { + this.outputShape = inputShape; + + this.dispatchLayout = flatDispatchLayout(this.outputShape); + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + + this.alignCorners = alignCorners; + this.shaderKey = `resizeBilinearBackprop_${alignCorners}`; + } + + getUserCode(): string { + const userCode = ` + ${main('index')} { + if (index < uniforms.size) { + let coords = getOutputCoords(); + let b = coords[0]; + let d = coords[3]; + let r = coords[1]; + let c = coords[2]; + + var accumulator = 0.0; + + // Compute bounds for where in dy we will look + let startRLerp = floor(f32(r) * uniforms.invHeightScale); + let startDyR = i32(startRLerp - f32(uniforms.winHeight / 2)); + + let startCLerp = floor(f32(c) * uniforms.invWidthScale); + let startDyC = i32(startCLerp - f32(uniforms.winWidth / 2)); + + // Loop over dy + for (var dyROffset = 0; dyROffset < uniforms.winHeight; dyROffset++) { + let dyR = startDyR + dyROffset; + + // Guard against the window exceeding the bounds of dy + if (dyR < 0 || dyR >= uniforms.dyShape[1]) { + continue; + } + + for (var dyCOffset = 0; dyCOffset < uniforms.winWidth; dyCOffset++) { + let dyC = startDyC + dyCOffset; + + // Guard against the window exceeding the bounds of dy + if (dyC < 0 || dyC >= uniforms.dyShape[2]) { + continue; + } + + let dxR = f32(dyR) * uniforms.heightScale; + let topDxRIndex = i32(floor(dxR)); + let bottomDxRIndex = i32(min(ceil(dxR), f32(uniforms.outShape[1] - 1))); + let dxRLerp = dxR - f32(topDxRIndex); + let inverseDxRLerp = 1.0 - dxRLerp; + + let dxC = f32(dyC) * uniforms.widthScale; + let leftDxCIndex = i32(floor(dxC)); + let rightDxCIndex = i32(min(ceil(dxC), f32(uniforms.outShape[2] - 1))); + let dxCLerp = dxC - f32(leftDxCIndex); + let inverseDxCLerp = 1.0 - dxCLerp; + + if (r == topDxRIndex && c == leftDxCIndex) { + // topLeft + accumulator += + getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp; + } + + if (r == topDxRIndex && c == rightDxCIndex) { + // topRight + accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp; + } + + if (r == bottomDxRIndex && c == leftDxCIndex) { + // bottomLeft + accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp; + } + + if (r == bottomDxRIndex && c == rightDxCIndex) { + // bottomRight + accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp; + } + } + } + // End loop over dy + + setOutputAtIndex(index, accumulator); + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index ec6697527c8..3b75aba346c 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -76,12 +76,6 @@ const TEST_FILTERS: TestFilter[] = [ 'gradients', // Not yet implemented ] }, - { - startsWith: 'resizeBilinear ', - excludes: [ - 'gradients', // Not yet implemented - ] - }, // exclude unsupported kernels and to be fixed cases {