From 8b40572706257982e498f6b69ebf8d1312cdf39e Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Wed, 17 Aug 2022 10:44:17 +0800 Subject: [PATCH 1/3] Strengthen frame readiness while testing fromPixels with video element (#6751) This fixes the error in WebGPU-backed unit tests: OperationError: Failed to execute 'importExternalTexture' on 'GPUDevice': Failed to import texture from video element that doesn't have back resource. --- .../src/from_pixels_webgpu_test.ts | 18 +++-------- tfjs-core/src/ops/from_pixels_test.ts | 15 +++------ tfjs-core/src/test_util.ts | 32 +++++++++++++++++++ 3 files changed, 40 insertions(+), 25 deletions(-) diff --git a/tfjs-backend-webgpu/src/from_pixels_webgpu_test.ts b/tfjs-backend-webgpu/src/from_pixels_webgpu_test.ts index 787c719da80..7e485b1c93a 100644 --- a/tfjs-backend-webgpu/src/from_pixels_webgpu_test.ts +++ b/tfjs-backend-webgpu/src/from_pixels_webgpu_test.ts @@ -16,6 +16,7 @@ */ import * as tf from '@tensorflow/tfjs-core'; +import {test_util} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from './backend_webgpu'; import {describeWebGPU} from './test_util'; @@ -37,26 +38,15 @@ describeWebGPU('fromPixels', () => { const textureManager = backend.textureManager; textureManager.dispose(); - const video = document.createElement('video'); const source = document.createElement('source'); source.src = // tslint:disable-next-line:max-line-length 'data:video/mp4;base64,AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1wNDEAAAAIZnJlZQAAAu1tZGF0AAACrQYF//+p3EXpvebZSLeWLNgg2SPu73gyNjQgLSBjb3JlIDE1NSByMjkwMSA3ZDBmZjIyIC0gSC4yNjQvTVBFRy00IEFWQyBjb2RlYyAtIENvcHlsZWZ0IDIwMDMtMjAxOCAtIGh0dHA6Ly93d3cudmlkZW9sYW4ub3JnL3gyNjQuaHRtbCAtIG9wdGlvbnM6IGNhYmFjPTEgcmVmPTMgZGVibG9jaz0xOjA6MCBhbmFseXNlPTB4MzoweDExMyBtZT1oZXggc3VibWU9NyBwc3k9MSBwc3lfcmQ9MS4wMDowLjAwIG1peGVkX3JlZj0xIG1lX3JhbmdlPTE2IGNocm9tYV9tZT0xIHRyZWxsaXM9MSA4eDhkY3Q9MSBjcW09MCBkZWFkem9uZT0yMSwxMSBmYXN0X3Bza2lwPTEgY2hyb21hX3FwX29mZnNldD0tMiB0aHJlYWRzPTMgbG9va2FoZWFkX3RocmVhZHM9MSBzbGljZWRfdGhyZWFkcz0wIG5yPTAgZGVjaW1hdGU9MSBpbnRlcmxhY2VkPTAgYmx1cmF5X2NvbXBhdD0wIGNvbnN0cmFpbmVkX2ludHJhPTAgYmZyYW1lcz0zIGJfcHlyYW1pZD0yIGJfYWRhcHQ9MSBiX2JpYXM9MCBkaXJlY3Q9MSB3ZWlnaHRiPTEgb3Blbl9nb3A9MCB3ZWlnaHRwPTIga2V5aW50PTI1MCBrZXlpbnRfbWluPTEgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVzaD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTI4LjAgcWNvbXA9MC42MCBxcG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAAwZYiEAD//8m+P5OXfBeLGOfKE3xkODvFZuBflHv/+VwJIta6cbpIo4ABLoKBaYTkTAAAC7m1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAAPoAAEAAAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAIYdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAAAAPoAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAACgAAAAWgAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAD6AAAAAAAAQAAAAABkG1kaWEAAAAgbWRoZAAAAAAAAAAAAAAAAAAAQAAAAEAAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAAVmlkZW9IYW5kbGVyAAAAATttaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVmAAAAAAAAAAEAAAAMdXJsIAAAAAEAAAD7c3RibAAAAJdzdHNkAAAAAAAAAAEAAACHYXZjMQAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAACgAFoASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQACv/hABhnZAAKrNlCjfkhAAADAAEAAAMAAg8SJZYBAAZo6+JLIsAAAAAYc3R0cwAAAAAAAAABAAAAAQAAQAAAAAAcc3RzYwAAAAAAAAABAAAAAQAAAAEAAAABAAAAFHN0c3oAAAAAAAAC5QAAAAEAAAAUc3RjbwAAAAAAAAABAAAAMAAAAGJ1ZHRhAAAAWm1ldGEAAAAAAAAAIWhkbHIAAAAAAAAAAG1kaXJhcHBsAAAAAAAAAAAAAAAALWlsc3QAAAAlqXRvbwAAAB1kYXRhAAAAAQAAAABMYXZmNTguMTIuMTAw'; source.type = 'video/mp4'; - video.appendChild(source); - document.body.appendChild(video); - - video.autoplay = true; - video.loop = true; - video.muted = true; - video.preload = 'auto'; - await video.play(); - // ensure video element to be loaded - if ('requestVideoFrameCallback' in video) { - // tslint:disable-next-line:no-any - await new Promise(go => (video as any).requestVideoFrameCallback(go)); - } + const video = await test_util.createVideoElement(source); + document.body.appendChild(video); + await test_util.play(video); { tf.env().set('WEBGPU_IMPORT_EXTERNAL_TEXTURE', true); diff --git a/tfjs-core/src/ops/from_pixels_test.ts b/tfjs-core/src/ops/from_pixels_test.ts index 1183e57667b..7cff8ed0a87 100644 --- a/tfjs-core/src/ops/from_pixels_test.ts +++ b/tfjs-core/src/ops/from_pixels_test.ts @@ -17,7 +17,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; -import {expectArraysClose, expectArraysEqual} from '../test_util'; +import {createVideoElement, expectArraysClose, expectArraysEqual, play} from '../test_util'; class MockContext { getImageData(x: number, y: number, width: number, height: number) { @@ -209,22 +209,15 @@ describeWithFlags('fromPixels', BROWSER_ENVS, () => { expect(data.length).toEqual(10 * 10 * 3); }); it('fromPixels for HTMLVideoElement', async () => { - const video = document.createElement('video'); - video.autoplay = true; const source = document.createElement('source'); source.src = // tslint:disable-next-line:max-line-length 'data:video/mp4;base64,AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1wNDEAAAAIZnJlZQAAAu1tZGF0AAACrQYF//+p3EXpvebZSLeWLNgg2SPu73gyNjQgLSBjb3JlIDE1NSByMjkwMSA3ZDBmZjIyIC0gSC4yNjQvTVBFRy00IEFWQyBjb2RlYyAtIENvcHlsZWZ0IDIwMDMtMjAxOCAtIGh0dHA6Ly93d3cudmlkZW9sYW4ub3JnL3gyNjQuaHRtbCAtIG9wdGlvbnM6IGNhYmFjPTEgcmVmPTMgZGVibG9jaz0xOjA6MCBhbmFseXNlPTB4MzoweDExMyBtZT1oZXggc3VibWU9NyBwc3k9MSBwc3lfcmQ9MS4wMDowLjAwIG1peGVkX3JlZj0xIG1lX3JhbmdlPTE2IGNocm9tYV9tZT0xIHRyZWxsaXM9MSA4eDhkY3Q9MSBjcW09MCBkZWFkem9uZT0yMSwxMSBmYXN0X3Bza2lwPTEgY2hyb21hX3FwX29mZnNldD0tMiB0aHJlYWRzPTMgbG9va2FoZWFkX3RocmVhZHM9MSBzbGljZWRfdGhyZWFkcz0wIG5yPTAgZGVjaW1hdGU9MSBpbnRlcmxhY2VkPTAgYmx1cmF5X2NvbXBhdD0wIGNvbnN0cmFpbmVkX2ludHJhPTAgYmZyYW1lcz0zIGJfcHlyYW1pZD0yIGJfYWRhcHQ9MSBiX2JpYXM9MCBkaXJlY3Q9MSB3ZWlnaHRiPTEgb3Blbl9nb3A9MCB3ZWlnaHRwPTIga2V5aW50PTI1MCBrZXlpbnRfbWluPTEgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVzaD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTI4LjAgcWNvbXA9MC42MCBxcG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAAwZYiEAD//8m+P5OXfBeLGOfKE3xkODvFZuBflHv/+VwJIta6cbpIo4ABLoKBaYTkTAAAC7m1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAAPoAAEAAAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAIYdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAAAAPoAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAACgAAAAWgAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAD6AAAAAAAAQAAAAABkG1kaWEAAAAgbWRoZAAAAAAAAAAAAAAAAAAAQAAAAEAAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAAVmlkZW9IYW5kbGVyAAAAATttaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVmAAAAAAAAAAEAAAAMdXJsIAAAAAEAAAD7c3RibAAAAJdzdHNkAAAAAAAAAAEAAACHYXZjMQAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAACgAFoASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQACv/hABhnZAAKrNlCjfkhAAADAAEAAAMAAg8SJZYBAAZo6+JLIsAAAAAYc3R0cwAAAAAAAAABAAAAAQAAQAAAAAAcc3RzYwAAAAAAAAABAAAAAQAAAAEAAAABAAAAFHN0c3oAAAAAAAAC5QAAAAEAAAAUc3RjbwAAAAAAAAABAAAAMAAAAGJ1ZHRhAAAAWm1ldGEAAAAAAAAAIWhkbHIAAAAAAAAAAG1kaXJhcHBsAAAAAAAAAAAAAAAALWlsc3QAAAAlqXRvbwAAAB1kYXRhAAAAAQAAAABMYXZmNTguMTIuMTAw'; source.type = 'video/mp4'; - video.appendChild(source); - document.body.appendChild(video); - // On mobile safari the ready state is ready immediately. - if (video.readyState < 2) { - await new Promise(resolve => { - video.addEventListener('loadeddata', () => resolve(video)); - }); - } + const video = await createVideoElement(source); + document.body.appendChild(video); + await play(video); const res = tf.browser.fromPixels(video); expect(res.shape).toEqual([90, 160, 3]); diff --git a/tfjs-core/src/test_util.ts b/tfjs-core/src/test_util.ts index 914950f08b0..3e82c630700 100644 --- a/tfjs-core/src/test_util.ts +++ b/tfjs-core/src/test_util.ts @@ -190,3 +190,35 @@ export function encodeStrings(a: RecursiveArray<{}>): } return a as RecursiveArray; } + +/** Creates an HTMLVideoElement with autoplay-friendly default settings. */ +export function createVideoElement(source: HTMLSourceElement): + Promise { + const video = document.createElement('video'); + if ('playsInline' in video) { + // tslint:disable-next-line:no-any + (video as any).playsInline = true; + } + video.muted = true; + video.loop = true; + video.style.position = 'fixed'; + video.style.left = '0px'; + video.style.top = '0px'; + + video.preload = 'auto'; + video.appendChild(source); + return new Promise(resolve => { + video.addEventListener('loadeddata', _ => resolve(video)); + video.load(); + }); +} + +export async function play(video: HTMLVideoElement) { + await video.play(); + if ('requestVideoFrameCallback' in video) { + await new Promise(resolve => { + // tslint:disable-next-line:no-any + (video as any).requestVideoFrameCallback(resolve); + }); + } +} From a366cc24edb4cc35d4889fe7fd441389d2beedd8 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 17 Aug 2022 14:24:39 +0800 Subject: [PATCH 2/3] webgpu: enlarge the splitted dimInner size (#6755) --- .../src/matmul_packed_webgpu.ts | 18 +++++++++++------- .../src/matmul_splitK_webgpu.ts | 9 +++++---- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index 3e5fcfcb6d1..ac9a0a36632 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -160,7 +160,7 @@ const calculateResultSnippet = export function makeMatMulPackedVec4Source( workPerThread: number[], workGroupSize: [number, number, number], - transposeA = false, tileInner = 32, splitK = false, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, isVectorA = false): string { const tileAOuter = workGroupSize[1] * workPerThread[1]; const tileBOuter = workGroupSize[0] * workPerThread[0]; @@ -209,8 +209,10 @@ export function makeMatMulPackedVec4Source( let batch = ${splitK ? '0' : 'i32(globalId.z)'}; let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'}; - var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'}; + let numTiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : + '(uniforms.dimInner - 1) / TileInner + 1'}; + var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, RowPerThread>; @@ -281,7 +283,8 @@ const readDataFromSubASnippet = (transposeA: boolean) => { export function makeMatMulPackedSource( workPerThread: number[], workGroupSize: [number, number, number], - transposeA = false, tileInner = 32, splitK = false): string { + transposeA = false, tileInner = 32, splitK = false, + splitedDimInner = 32): string { const tileAOuter = workPerThread[1] * workGroupSize[1]; const tileBOuter = workPerThread[0] * workGroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -323,8 +326,9 @@ export function makeMatMulPackedSource( let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; let numTiles = ${ - splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'}; - var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'}; + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : + '(uniforms.dimInner - 1) / TileInner + 1'}; + var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, RowPerThread>; @@ -565,7 +569,7 @@ export class MatMulPackedProgram implements WebGPUProgram { this.isVec4 ? makeMatMulPackedVec4Source( this.elementsPerThread, this.workGroupSize, this.transposeA, - this.tileInner, false, this.isVectorA) : + this.tileInner, false, null, this.isVectorA) : (this.isVectorA ? makeVectorMatrixProductSource( this.workGroupSize, this.transposeA) : makeMatMulPackedSource( diff --git a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts index ac6ea700f64..0e443f7a63b 100644 --- a/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts @@ -37,7 +37,7 @@ export class MatMulSplitKProgram implements WebGPUProgram { batchAEqualOne: boolean; batchBEqualOne: boolean; isVec4 = false; - tileInner = 32; + splitedDimInner = 128; constructor( outputShape: [number, number, number], dimInner: number, @@ -51,7 +51,8 @@ export class MatMulSplitKProgram implements WebGPUProgram { this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 || !transposeA && dimInner % 4 === 0) && this.outputShape[2] % 4 === 0; - this.elementsPerThread = [4, 4, this.tileInner]; + this.elementsPerThread = [4, 4, this.splitedDimInner]; + if (!this.isVec4) { if (this.outputShape[1] < 16) { this.elementsPerThread[1] = 1; @@ -119,10 +120,10 @@ export class MatMulSplitKProgram implements WebGPUProgram { ${ this.isVec4 ? makeMatMulPackedVec4Source( this.elementsPerThread, this.workGroupSize, - this.transposeA, this.tileInner, true) : + this.transposeA, 32, true, this.splitedDimInner) : makeMatMulPackedSource( this.elementsPerThread, this.workGroupSize, - this.transposeA, this.tileInner, true)} + this.transposeA, 32, true, this.splitedDimInner)} `; return userCode; } From d515f4da8eedb2333f7e1f2d1b78b9ec18297d4b Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Wed, 17 Aug 2022 15:48:45 +0800 Subject: [PATCH 3/3] [webgpu] s/ScatterOptimizedProgram/ScatterProgram/g (#6761) --- tfjs-backend-webgpu/src/kernels/ScatterNd.ts | 4 +- .../src/kernels/SparseToDense.ts | 8 +- .../src/scatter_optimized_webgpu.ts | 142 ---------------- tfjs-backend-webgpu/src/scatter_webgpu.ts | 157 +++++++++++------- tfjs-backend-webgpu/src/setup_test.ts | 7 - 5 files changed, 102 insertions(+), 216 deletions(-) delete mode 100644 tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts diff --git a/tfjs-backend-webgpu/src/kernels/ScatterNd.ts b/tfjs-backend-webgpu/src/kernels/ScatterNd.ts index ebad7ee2ae1..c72a8fc1572 100644 --- a/tfjs-backend-webgpu/src/kernels/ScatterNd.ts +++ b/tfjs-backend-webgpu/src/kernels/ScatterNd.ts @@ -18,10 +18,10 @@ import {backend_util, KernelConfig, KernelFunc, ScatterNd, ScatterNdAttrs, ScatterNdInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {WebGPUBackend} from '../backend_webgpu'; +import {ScatterProgram} from '../scatter_webgpu'; import {fill} from './Fill'; import {reshape} from './Reshape'; -import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu'; export function scatterNd(args: { inputs: ScatterNdInputs, @@ -54,7 +54,7 @@ export function scatterNd(args: { {type: 'int32', data: [sliceRank]}, {type: 'int32', data: strides}, {type: 'int32', data: [size]} ]; - const program = new ScatterOptimizedProgram( + const program = new ScatterProgram( flattenX.shape, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, type); const res = backend.runWebGPUProgram( diff --git a/tfjs-backend-webgpu/src/kernels/SparseToDense.ts b/tfjs-backend-webgpu/src/kernels/SparseToDense.ts index a4015e7cbb6..35a4ca8a430 100644 --- a/tfjs-backend-webgpu/src/kernels/SparseToDense.ts +++ b/tfjs-backend-webgpu/src/kernels/SparseToDense.ts @@ -19,7 +19,7 @@ import {backend_util, KernelConfig, KernelFunc, Rank, SparseToDense, SparseToDen import {WebGPUBackend} from '../backend_webgpu'; import {scatterImplCPU} from '../kernel_utils/shared'; -import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu'; +import {ScatterProgram} from '../scatter_webgpu'; import {identity} from './Identity'; import {reshape} from './Reshape'; @@ -89,7 +89,7 @@ export function sparseToDense(args: { break; case 1: if (true) { - const program = new ScatterOptimizedProgram( + const program = new ScatterProgram( [numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length, $sparseValues.shape.length, strides, flattenShape, type, sumDupeIndices); @@ -101,7 +101,7 @@ export function sparseToDense(args: { default: if (true) { // First replace the default value with 0 at indices. - const program = new ScatterOptimizedProgram( + const program = new ScatterProgram( [numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length, zero.shape.length, strides, flattenShape, type, sumDupeIndices); backend.runWebGPUProgram( @@ -109,7 +109,7 @@ export function sparseToDense(args: { } { // Then replace 0 with the (sum of) sparse value(s) at indices. - const program = new ScatterOptimizedProgram( + const program = new ScatterProgram( [numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length, $sparseValues.shape.length, strides, flattenShape, type); backend.runWebGPUProgram( diff --git a/tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts b/tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts deleted file mode 100644 index 0ead77b8bae..00000000000 --- a/tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts +++ /dev/null @@ -1,142 +0,0 @@ -/** - * @license - * Copyright 2021 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {DataType} from '@tensorflow/tfjs-core'; -import {getCoordsDataType, getMainHeaderAndGlobalIndexString, mapToWgslTypes, WebGPUProgram} from './webgpu_program'; -import {computeDispatch, flatDispatchLayout} from './webgpu_util'; - -export class ScatterOptimizedProgram implements WebGPUProgram { - variableNames = ['updates', 'indices']; - uniforms: string; - outputShape: number[]; - sumDupeIndices: boolean; - shaderKey: string; - dispatchLayout: {x: number[]}; - dispatch: [number, number, number]; - workGroupSize: [number, number, number] = [64, 1, 1]; - updatesRank: number; - indicesRank: number; - sliceDimGreaterThanOne: boolean; - atomic = true; - type: DataType; - - constructor( - flattenXShape: number[], sliceDim: number, indicesRank: number, - updatesRank: number, strides: number[], shape: number[], - outputDtype: DataType, sumDupeIndices = true) { - this.outputShape = shape; - this.type = outputDtype; - this.sumDupeIndices = sumDupeIndices; - this.dispatchLayout = flatDispatchLayout(flattenXShape); - // Dispatching based on |updates| shape instead of output shape. - this.dispatch = - computeDispatch(this.dispatchLayout, flattenXShape, this.workGroupSize); - this.sliceDimGreaterThanOne = sliceDim > 1; - this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${ - this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`; - const stridesType = getCoordsDataType(strides.length); - this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`; - this.updatesRank = updatesRank; - this.indicesRank = indicesRank; - } - - getUserCode(): string { - let indicesString = ''; - if (this.indicesRank === 1) { - indicesString = 'coords[0]'; - } else if (this.indicesRank === 2) { - indicesString = 'coords[0], j'; - } - const indicesSnippet = `getIndices(${indicesString})`; - - const strideString = this.sliceDimGreaterThanOne ? 'uniforms.strides[j]' : - 'uniforms.strides'; - - let outCoordsString = ''; - let getUpdatesCoordsFromFlatIndex = ''; - if (this.dispatchLayout.x.length === 1) { - outCoordsString = 'flattenedIndex'; - getUpdatesCoordsFromFlatIndex = ` - fn getUpdatesCoordsFromFlatIndex(index : i32) -> i32 { - return index; - } - `; - } else if (this.dispatchLayout.x.length === 2) { - outCoordsString = 'vec2(flattenedIndex, coords[1])'; - getUpdatesCoordsFromFlatIndex = ` - fn getUpdatesCoordsFromFlatIndex(index : i32) -> vec2 { - // N.B. |updates| could be a scalar tensor, conceptually representing a - // 2D tensor with all values equal to that. By design, its size must be - // the same as |outShape[1]| in one dimension, and |indicesShape[0]| - // gives the other. - let sliceSize = uniforms.outShape[1]; - let d0 = index / sliceSize; - let d1 = index - d0 * sliceSize; - return vec2(d0, d1); - } - `; - } - const updatesString = - Array.from({length: this.updatesRank}, (_, idx) => `coords[${idx}]`); - const updatesSnippet = `getUpdates(${updatesString.join(', ')})`; - - const atomicRMW = (ptr: string, val: string) => { - let atomicAddSnippet = `atomicAdd(${ptr}, bitcast(${val}))`; - if (this.type === 'float32') { - atomicAddSnippet = ` - { - var oldBits = 0; - var newBits = bitcast(${val}); - loop { - let info = atomicCompareExchangeWeak(${ptr}, oldBits, newBits); - if (info.exchanged) { - break; - } - oldBits = info.old_value; - let oldValue = bitcast(oldBits); - let newValue = oldValue + (${val}); - newBits = bitcast(newValue); - } - } - `; - } - const atomicStoreSnippet = `atomicStore(${ptr}, bitcast(${val}));`; - return this.sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet; - }; - - const userCode = ` - ${getUpdatesCoordsFromFlatIndex} - - ${getMainHeaderAndGlobalIndexString()} - - if (index < uniforms.size) { - let coords = getUpdatesCoordsFromFlatIndex(index); - var flattenedIndex = 0; - for (var j = 0; j < uniforms.sliceDim; j = j + 1) { - let indexInside = i32(round(${indicesSnippet})); - flattenedIndex = flattenedIndex + indexInside * ${strideString}; - } - let updateValue = - ${mapToWgslTypes(this.type, false)}(${updatesSnippet}); - let flatIndex = getOutputIndexFromCoords(${outCoordsString}); - - ${atomicRMW('&result[flatIndex]', 'updateValue')}; - } - }`; - return userCode; - } -} diff --git a/tfjs-backend-webgpu/src/scatter_webgpu.ts b/tfjs-backend-webgpu/src/scatter_webgpu.ts index 6c3749c1396..cf65925e5af 100644 --- a/tfjs-backend-webgpu/src/scatter_webgpu.ts +++ b/tfjs-backend-webgpu/src/scatter_webgpu.ts @@ -15,91 +15,126 @@ * ============================================================================= */ -import {getCoordsDataType, getMainHeaderAndGlobalIndexString, WebGPUProgram} from './webgpu_program'; +import {DataType} from '@tensorflow/tfjs-core'; +import {getCoordsDataType, getMainHeaderAndGlobalIndexString, mapToWgslTypes, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class ScatterProgram implements WebGPUProgram { - variableNames = ['updates', 'indices', 'defaultValue']; + variableNames = ['updates', 'indices']; uniforms: string; outputShape: number[]; + sumDupeIndices: boolean; shaderKey: string; dispatchLayout: {x: number[]}; dispatch: [number, number, number]; workGroupSize: [number, number, number] = [64, 1, 1]; - workPerThread = 4; - size = true; - indicesSnippet: string; - strideString: string; - updatesSnippet: string; + updatesRank: number; + indicesRank: number; + sliceDimGreaterThanOne: boolean; + atomic = true; + type: DataType; constructor( - updateSize: number, sliceDim: number, indicesRank: number, + flattenXShape: number[], sliceDim: number, indicesRank: number, updatesRank: number, strides: number[], shape: number[], - summingDupeIndex = true) { + outputDtype: DataType, sumDupeIndices = true) { this.outputShape = shape; - this.dispatchLayout = flatDispatchLayout(this.outputShape); - this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, this.workGroupSize, - [this.workPerThread, 1, 1]); - const sliceDimGreaterThanOne = sliceDim > 1; - this.shaderKey = - `scatter_${indicesRank}_${updatesRank}_${sliceDimGreaterThanOne}`; + this.type = outputDtype; + this.sumDupeIndices = sumDupeIndices; + this.dispatchLayout = flatDispatchLayout(flattenXShape); + // Dispatching based on |updates| shape instead of output shape. + this.dispatch = + computeDispatch(this.dispatchLayout, flattenXShape, this.workGroupSize); + this.sliceDimGreaterThanOne = sliceDim > 1; + this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${ + this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`; const stridesType = getCoordsDataType(strides.length); - this.uniforms = - `updateSize : i32, sliceDim : i32, strides: ${stridesType},`; + this.uniforms = `sliceDim : i32, strides: ${stridesType}, size: i32,`; + this.updatesRank = updatesRank; + this.indicesRank = indicesRank; + } + + getUserCode(): string { let indicesString = ''; - if (indicesRank === 1) { - indicesString = 'i'; - } else if (indicesRank === 2) { - indicesString = 'i, j'; + if (this.indicesRank === 1) { + indicesString = 'coords[0]'; + } else if (this.indicesRank === 2) { + indicesString = 'coords[0], j'; } - this.indicesSnippet = `getIndices(${indicesString})`; + const indicesSnippet = `getIndices(${indicesString})`; + + const strideString = this.sliceDimGreaterThanOne ? 'uniforms.strides[j]' : + 'uniforms.strides'; - let updatesString = ''; - if (updatesRank === 1) { - updatesString = 'i'; - } else if (updatesRank === 2) { - updatesString = 'i, coords[1]'; + let outCoordsString = ''; + let getUpdatesCoordsFromFlatIndex = ''; + if (this.dispatchLayout.x.length === 1) { + outCoordsString = 'flattenedIndex'; + getUpdatesCoordsFromFlatIndex = ` + fn getUpdatesCoordsFromFlatIndex(index : i32) -> i32 { + return index; + } + `; + } else if (this.dispatchLayout.x.length === 2) { + outCoordsString = 'vec2(flattenedIndex, coords[1])'; + getUpdatesCoordsFromFlatIndex = ` + fn getUpdatesCoordsFromFlatIndex(index : i32) -> vec2 { + // N.B. |updates| could be a scalar tensor, conceptually representing a + // 2D tensor with all values equal to that. By design, its size must be + // the same as |outShape[1]| in one dimension, and |indicesShape[0]| + // gives the other. + let sliceSize = uniforms.outShape[1]; + let d0 = index / sliceSize; + let d1 = index - d0 * sliceSize; + return vec2(d0, d1); + } + `; } - this.updatesSnippet = `getUpdates(${updatesString})`; + const updatesString = + Array.from({length: this.updatesRank}, (_, idx) => `coords[${idx}]`); + const updatesSnippet = `getUpdates(${updatesString.join(', ')})`; - this.strideString = - sliceDimGreaterThanOne ? 'uniforms.strides[j]' : 'uniforms.strides'; - } + const atomicRMW = (ptr: string, val: string) => { + let atomicAddSnippet = `atomicAdd(${ptr}, bitcast(${val}))`; + if (this.type === 'float32') { + atomicAddSnippet = ` + { + var oldBits = 0; + var newBits = bitcast(${val}); + loop { + let info = atomicCompareExchangeWeak(${ptr}, oldBits, newBits); + if (info.exchanged) { + break; + } + oldBits = info.old_value; + let oldValue = bitcast(oldBits); + let newValue = oldValue + (${val}); + newBits = bitcast(newValue); + } + } + `; + } + const atomicStoreSnippet = `atomicStore(${ptr}, bitcast(${val}));`; + return this.sumDupeIndices ? atomicAddSnippet : atomicStoreSnippet; + }; - getUserCode(): string { const userCode = ` + ${getUpdatesCoordsFromFlatIndex} + ${getMainHeaderAndGlobalIndexString()} - let globalIndex = index * ${this.workPerThread}; - if (globalIndex < uniforms.size) { - var sum = vec4(0.0); - var found = vec4(false); - for (var i = 0; i < uniforms.updateSize; i = i + 1) { - var flattenedIndex = 0; - for (var j = 0; j < uniforms.sliceDim; j = j + 1) { - let indexInside = i32(round(${this.indicesSnippet})); - flattenedIndex = flattenedIndex + indexInside * ${ - this.strideString}; - } - for (var innerIndex = 0; innerIndex < ${ - this.workPerThread}; innerIndex = innerIndex + 1) { - let curIndex = globalIndex + innerIndex; - let coords = getCoordsFromIndex(curIndex); - if (flattenedIndex == coords[0]) { - sum[innerIndex] = sum[innerIndex] + ${this.updatesSnippet}; - found[innerIndex] = true; - } - } - } - for (var innerIndex = 0; innerIndex < ${ - this.workPerThread}; innerIndex = innerIndex + 1) { - let curIndex = globalIndex + innerIndex; - if (curIndex < uniforms.size) - { - setOutputAtIndex(curIndex, mix(getDefaultValue(), sum[innerIndex], f32(found[innerIndex]))); - } + if (index < uniforms.size) { + let coords = getUpdatesCoordsFromFlatIndex(index); + var flattenedIndex = 0; + for (var j = 0; j < uniforms.sliceDim; j = j + 1) { + let indexInside = i32(round(${indicesSnippet})); + flattenedIndex = flattenedIndex + indexInside * ${strideString}; } + let updateValue = + ${mapToWgslTypes(this.type, false)}(${updatesSnippet}); + let flatIndex = getOutputIndexFromCoords(${outCoordsString}); + + ${atomicRMW('&result[flatIndex]', 'updateValue')}; } }`; return userCode; diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 58e4ee0fd7d..c0fe2a4f581 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -221,13 +221,6 @@ const TEST_FILTERS: TestFilter[] = [ 'accepts a tensor-like object', ] }, - { - startsWith: 'sparseToDense ', - excludes: [ - // TODO: Fix 0-sized buffer binding on WebGPU - '0-sized', // Not yet implemented. - ] - }, { startsWith: 'square ', excludes: [