Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/tensorflow/tfjs into jax2…
Browse files Browse the repository at this point in the history
…tfjs
  • Loading branch information
marcvanzee committed Aug 17, 2022
2 parents 1dac6d1 + d515f4d commit 1b63e5d
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 252 deletions.
18 changes: 4 additions & 14 deletions tfjs-backend-webgpu/src/from_pixels_webgpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import * as tf from '@tensorflow/tfjs-core';
import {test_util} from '@tensorflow/tfjs-core';
import {WebGPUBackend} from './backend_webgpu';
import {describeWebGPU} from './test_util';

Expand All @@ -37,26 +38,15 @@ describeWebGPU('fromPixels', () => {
const textureManager = backend.textureManager;
textureManager.dispose();

const video = document.createElement('video');
const source = document.createElement('source');
source.src =
// tslint:disable-next-line:max-line-length
'data:video/mp4;base64,AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1wNDEAAAAIZnJlZQAAAu1tZGF0AAACrQYF//+p3EXpvebZSLeWLNgg2SPu73gyNjQgLSBjb3JlIDE1NSByMjkwMSA3ZDBmZjIyIC0gSC4yNjQvTVBFRy00IEFWQyBjb2RlYyAtIENvcHlsZWZ0IDIwMDMtMjAxOCAtIGh0dHA6Ly93d3cudmlkZW9sYW4ub3JnL3gyNjQuaHRtbCAtIG9wdGlvbnM6IGNhYmFjPTEgcmVmPTMgZGVibG9jaz0xOjA6MCBhbmFseXNlPTB4MzoweDExMyBtZT1oZXggc3VibWU9NyBwc3k9MSBwc3lfcmQ9MS4wMDowLjAwIG1peGVkX3JlZj0xIG1lX3JhbmdlPTE2IGNocm9tYV9tZT0xIHRyZWxsaXM9MSA4eDhkY3Q9MSBjcW09MCBkZWFkem9uZT0yMSwxMSBmYXN0X3Bza2lwPTEgY2hyb21hX3FwX29mZnNldD0tMiB0aHJlYWRzPTMgbG9va2FoZWFkX3RocmVhZHM9MSBzbGljZWRfdGhyZWFkcz0wIG5yPTAgZGVjaW1hdGU9MSBpbnRlcmxhY2VkPTAgYmx1cmF5X2NvbXBhdD0wIGNvbnN0cmFpbmVkX2ludHJhPTAgYmZyYW1lcz0zIGJfcHlyYW1pZD0yIGJfYWRhcHQ9MSBiX2JpYXM9MCBkaXJlY3Q9MSB3ZWlnaHRiPTEgb3Blbl9nb3A9MCB3ZWlnaHRwPTIga2V5aW50PTI1MCBrZXlpbnRfbWluPTEgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVzaD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTI4LjAgcWNvbXA9MC42MCBxcG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAAwZYiEAD//8m+P5OXfBeLGOfKE3xkODvFZuBflHv/+VwJIta6cbpIo4ABLoKBaYTkTAAAC7m1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAAPoAAEAAAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAIYdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAAAAPoAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAACgAAAAWgAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAD6AAAAAAAAQAAAAABkG1kaWEAAAAgbWRoZAAAAAAAAAAAAAAAAAAAQAAAAEAAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAAVmlkZW9IYW5kbGVyAAAAATttaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVmAAAAAAAAAAEAAAAMdXJsIAAAAAEAAAD7c3RibAAAAJdzdHNkAAAAAAAAAAEAAACHYXZjMQAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAACgAFoASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQACv/hABhnZAAKrNlCjfkhAAADAAEAAAMAAg8SJZYBAAZo6+JLIsAAAAAYc3R0cwAAAAAAAAABAAAAAQAAQAAAAAAcc3RzYwAAAAAAAAABAAAAAQAAAAEAAAABAAAAFHN0c3oAAAAAAAAC5QAAAAEAAAAUc3RjbwAAAAAAAAABAAAAMAAAAGJ1ZHRhAAAAWm1ldGEAAAAAAAAAIWhkbHIAAAAAAAAAAG1kaXJhc
HBsAAAAAAAAAAAAAAAALWlsc3QAAAAlqXRvbwAAAB1kYXRhAAAAAQAAAABMYXZmNTguMTIuMTAw';
source.type = 'video/mp4';
video.appendChild(source);
document.body.appendChild(video);

video.autoplay = true;
video.loop = true;
video.muted = true;
video.preload = 'auto';
await video.play();

// Ensure the video element is loaded.
if ('requestVideoFrameCallback' in video) {
// tslint:disable-next-line:no-any
await new Promise(go => (video as any).requestVideoFrameCallback(go));
}
const video = await test_util.createVideoElement(source);
document.body.appendChild(video);
await test_util.play(video);

{
tf.env().set('WEBGPU_IMPORT_EXTERNAL_TEXTURE', true);
Expand Down
4 changes: 2 additions & 2 deletions tfjs-backend-webgpu/src/kernels/ScatterNd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import {backend_util, KernelConfig, KernelFunc, ScatterNd, ScatterNdAttrs, ScatterNdInputs, TensorInfo, util} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';
import {ScatterProgram} from '../scatter_webgpu';

import {fill} from './Fill';
import {reshape} from './Reshape';
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';

export function scatterNd(args: {
inputs: ScatterNdInputs,
Expand Down Expand Up @@ -54,7 +54,7 @@ export function scatterNd(args: {
{type: 'int32', data: [sliceRank]}, {type: 'int32', data: strides},
{type: 'int32', data: [size]}
];
const program = new ScatterOptimizedProgram(
const program = new ScatterProgram(
flattenX.shape, sliceRank, flattenIndices.shape.length,
flattenX.shape.length, strides, flattenShape, type);
const res = backend.runWebGPUProgram(
Expand Down
8 changes: 4 additions & 4 deletions tfjs-backend-webgpu/src/kernels/SparseToDense.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {backend_util, KernelConfig, KernelFunc, Rank, SparseToDense, SparseToDen

import {WebGPUBackend} from '../backend_webgpu';
import {scatterImplCPU} from '../kernel_utils/shared';
import {ScatterOptimizedProgram} from '../scatter_optimized_webgpu';
import {ScatterProgram} from '../scatter_webgpu';

import {identity} from './Identity';
import {reshape} from './Reshape';
Expand Down Expand Up @@ -89,7 +89,7 @@ export function sparseToDense(args: {
break;
case 1:
if (true) {
const program = new ScatterOptimizedProgram(
const program = new ScatterProgram(
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
$sparseValues.shape.length, strides, flattenShape, type,
sumDupeIndices);
Expand All @@ -101,15 +101,15 @@ export function sparseToDense(args: {
default:
if (true) {
// First replace the default value with 0 at indices.
const program = new ScatterOptimizedProgram(
const program = new ScatterProgram(
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
zero.shape.length, strides, flattenShape, type, sumDupeIndices);
backend.runWebGPUProgram(
program, [zero, $sparseIndices], type, uniformData, $denseValues);
}
{
// Then replace 0 with the (sum of) sparse value(s) at indices.
const program = new ScatterOptimizedProgram(
const program = new ScatterProgram(
[numUpdates, sliceSize], sliceRank, $sparseIndices.shape.length,
$sparseValues.shape.length, strides, flattenShape, type);
backend.runWebGPUProgram(
Expand Down
18 changes: 11 additions & 7 deletions tfjs-backend-webgpu/src/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ const calculateResultSnippet =

export function makeMatMulPackedVec4Source(
workPerThread: number[], workGroupSize: [number, number, number],
transposeA = false, tileInner = 32, splitK = false,
transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32,
isVectorA = false): string {
const tileAOuter = workGroupSize[1] * workPerThread[1];
const tileBOuter = workGroupSize[0] * workPerThread[0];
Expand Down Expand Up @@ -209,8 +209,10 @@ export function makeMatMulPackedVec4Source(
let batch = ${splitK ? '0' : 'i32(globalId.z)'};
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let numTiles = ${splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'};
var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'};
let numTiles = ${
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` :
'(uniforms.dimInner - 1) / TileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
var acc: array<vec4<f32>, RowPerThread>;
Expand Down Expand Up @@ -281,7 +283,8 @@ const readDataFromSubASnippet = (transposeA: boolean) => {

export function makeMatMulPackedSource(
workPerThread: number[], workGroupSize: [number, number, number],
transposeA = false, tileInner = 32, splitK = false): string {
transposeA = false, tileInner = 32, splitK = false,
splitedDimInner = 32): string {
const tileAOuter = workPerThread[1] * workGroupSize[1];
const tileBOuter = workPerThread[0] * workGroupSize[0];
const tileAWidth = transposeA ? tileAOuter : tileInner;
Expand Down Expand Up @@ -323,8 +326,9 @@ export function makeMatMulPackedSource(
let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
let numTiles = ${
splitK ? '1' : '(uniforms.dimInner - 1) / TileInner + 1'};
var kStart = ${splitK ? 'i32(globalId.z) * TileInner' : '0'};
splitK ? `${Math.ceil(splitedDimInner / tileInner)}` :
'(uniforms.dimInner - 1) / TileInner + 1'};
var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
var acc : array<array<f32, ColPerThread>, RowPerThread>;
Expand Down Expand Up @@ -565,7 +569,7 @@ export class MatMulPackedProgram implements WebGPUProgram {
this.isVec4 ?
makeMatMulPackedVec4Source(
this.elementsPerThread, this.workGroupSize, this.transposeA,
this.tileInner, false, this.isVectorA) :
this.tileInner, false, null, this.isVectorA) :
(this.isVectorA ? makeVectorMatrixProductSource(
this.workGroupSize, this.transposeA) :
makeMatMulPackedSource(
Expand Down
9 changes: 5 additions & 4 deletions tfjs-backend-webgpu/src/matmul_splitK_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export class MatMulSplitKProgram implements WebGPUProgram {
batchAEqualOne: boolean;
batchBEqualOne: boolean;
isVec4 = false;
tileInner = 32;
splitedDimInner = 128;

constructor(
outputShape: [number, number, number], dimInner: number,
Expand All @@ -51,7 +51,8 @@ export class MatMulSplitKProgram implements WebGPUProgram {
this.isVec4 = (transposeA && this.outputShape[1] % 4 === 0 ||
!transposeA && dimInner % 4 === 0) &&
this.outputShape[2] % 4 === 0;
this.elementsPerThread = [4, 4, this.tileInner];
this.elementsPerThread = [4, 4, this.splitedDimInner];

if (!this.isVec4) {
if (this.outputShape[1] < 16) {
this.elementsPerThread[1] = 1;
Expand Down Expand Up @@ -119,10 +120,10 @@ export class MatMulSplitKProgram implements WebGPUProgram {
${
this.isVec4 ? makeMatMulPackedVec4Source(
this.elementsPerThread, this.workGroupSize,
this.transposeA, this.tileInner, true) :
this.transposeA, 32, true, this.splitedDimInner) :
makeMatMulPackedSource(
this.elementsPerThread, this.workGroupSize,
this.transposeA, this.tileInner, true)}
this.transposeA, 32, true, this.splitedDimInner)}
`;
return userCode;
}
Expand Down
142 changes: 0 additions & 142 deletions tfjs-backend-webgpu/src/scatter_optimized_webgpu.ts

This file was deleted.

Loading

0 comments on commit 1b63e5d

Please sign in to comment.