Skip to content

Commit

Permalink
Move to platform_browser
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang Gu committed Aug 6, 2022
1 parent 2502a86 commit b87bf71
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 106 deletions.
6 changes: 5 additions & 1 deletion tfjs-backend-webgl/src/gpgpu_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,15 @@ export class GPGPUContext {
return;
}
// Start a new loop that polls.
let scheduleFn = undefined;
if ('setTimeoutCustom' in env().platform) {
scheduleFn = env().platform.setTimeoutCustom.bind(env().platform);
}
util.repeatedTry(() => {
this.pollItems();
// End the loop if no more items to poll.
return this.itemsToPoll.length === 0;
});
}, () => 0, null, scheduleFn);
}

private bindTextureToFrameBuffer(texture: WebGLTexture) {
Expand Down
4 changes: 2 additions & 2 deletions tfjs-core/src/flags.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,5 @@ ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);
/** Whether to enable canvas2d willReadFrequently for GPU backends */
ENV.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false);

/** Whether to use setTimeoutWPM to replace setTimeout in WebGL data() */
ENV.registerFlag('USE_SETTIMEOUTWPM', () => false);
/** Whether to use setTimeoutCustom */
ENV.registerFlag('USE_SETTIMEOUTCUSTOM', () => false);
2 changes: 2 additions & 0 deletions tfjs-core/src/platforms/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@ export interface Platform {
encode(text: string, encoding: string): Uint8Array;
/** Decode the provided bytes into a string using the provided encoding. */
decode(bytes: Uint8Array, encoding: string): string;

setTimeoutCustom?(functionRef: Function, delay: number): void;
}
39 changes: 39 additions & 0 deletions tfjs-core/src/platforms/platform_browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ export class PlatformBrowser implements Platform {
// https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
private textEncoder: TextEncoder;

// For setTimeoutCustom
private messageName = 'setTimeoutCustom';
private functionRefs: Function[] = [];
private handledMessageCount = 0;
private hasEventListener = false;

fetch(path: string, init?: RequestInit): Promise<Response> {
return fetch(path, init);
}
Expand All @@ -50,6 +56,39 @@ export class PlatformBrowser implements Platform {
decode(bytes: Uint8Array, encoding: string): string {
return new TextDecoder(encoding).decode(bytes);
}

// If the setTimeout nesting level is greater than 5 and timeout is less
// than 4ms, timeout will be clamped to 4ms, which hurts the perf.
// Interleaving window.postMessage and setTimeout will trick the browser and
// avoid the clamp.
setTimeoutCustom(functionRef: Function, delay: number): void {
if (!window || !env().getBool('USE_SETTIMEOUTCUSTOM')) {
setTimeout(functionRef, delay);
return;
}

this.functionRefs.push(functionRef);
setTimeout(() => {
window.postMessage(
{name: this.messageName, index: this.functionRefs.length - 1}, '*');
}, delay);

if (!this.hasEventListener) {
this.hasEventListener = true;
window.addEventListener('message', (event: MessageEvent) => {
if (event.source === window && event.data.name === this.messageName) {
event.stopPropagation();
const functionRef = this.functionRefs[event.data.index];
functionRef();
this.handledMessageCount++;
if (this.handledMessageCount === this.functionRefs.length) {
this.functionRefs = [];
this.handledMessageCount = 0;
}
}
}, true);
}
}
}

if (env().get('IS_BROWSER')) {
Expand Down
72 changes: 72 additions & 0 deletions tfjs-core/src/platforms/platform_browser_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* =============================================================================
*/

import {env} from '../environment';
import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util';

import {PlatformBrowser} from './platform_browser';
Expand Down Expand Up @@ -88,3 +89,74 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => {
expect(s).toEqual('Здраво');
});
});

describe('setTimeout', () => {
const totalCount = 100;
// Skip the first few samples because the browser does not clamp the timeout
const skipCount = 5;

it('setTimeout', (done) => {
let count = 0;
let startTime = performance.now();
let totalTime = 0;
if (env().platformName === 'browser') {
setTimeout(_testSetTimeout, 0);
} else {
expect().nothing();
}

function _testSetTimeout() {
const endTime = performance.now();
count++;
if (count > skipCount) {
totalTime += endTime - startTime;
}
if (count === totalCount) {
const averageTime = totalTime / (totalCount - skipCount);
console.log(`averageTime of setTimeout is ${averageTime} ms`);
expect(averageTime).toBeGreaterThan(4);
done();
return;
}
startTime = performance.now();
setTimeout(_testSetTimeout, 0);
}
});

it('setTimeoutCustom', (done) => {
let count = 0;
let startTime = performance.now();
let totalTime = 0;
let originUseSettimeoutcustom: boolean;

if (env().platformName === 'browser') {
originUseSettimeoutcustom = env().getBool('USE_SETTIMEOUTCUSTOM');
env().set('USE_SETTIMEOUTCUSTOM', true);
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);
} else {
expect().nothing();
}

function _testSetTimeoutCustom() {
const endTime = performance.now();
count++;
if (count > skipCount) {
totalTime += endTime - startTime;
}
if (count === totalCount) {
const averageTime = totalTime / (totalCount - skipCount);
console.log(`averageTime of setTimeoutCustom is ${averageTime} ms`);
if (window) {
expect(averageTime).toBeLessThan(4);
} else {
expect(averageTime).toBeGreaterThan(4);
}
done();
env().set('USE_SETTIMEOUTCUSTOM', originUseSettimeoutcustom);
return;
}
startTime = performance.now();
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);
}
});
});
60 changes: 12 additions & 48 deletions tfjs-core/src/util_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
* =============================================================================
*/

import {env} from './environment';
import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types';

/**
Expand Down Expand Up @@ -304,7 +303,7 @@ export function rightPad(a: string, size: number): string {

export function repeatedTry(
checkFn: () => boolean, delayFn = (counter: number) => 0,
maxCounter?: number): Promise<void> {
maxCounter?: number, scheduleFn?: Function): Promise<void> {
return new Promise<void>((resolve, reject) => {
let tryCount = 0;

Expand All @@ -322,10 +321,10 @@ export function repeatedTry(
reject();
return;
}
if (env().getBool('USE_SETTIMEOUTWPM')) {
(window as any).setTimeoutWPM(tryFn, nextBackoff);
} else {
if (typeof scheduleFn === 'undefined') {
setTimeout(tryFn, nextBackoff);
} else {
scheduleFn(tryFn, nextBackoff);
}
};

Expand Down Expand Up @@ -529,9 +528,9 @@ export function bytesPerElement(dtype: DataType): number {

/**
* Returns the approximate number of bytes allocated in the string array - 2
* bytes per character. Computing the exact bytes for a native string in JS is
* not possible since it depends on the encoding of the html page that serves
* the website.
* bytes per character. Computing the exact bytes for a native string in JS
* is not possible since it depends on the encoding of the html page that
* serves the website.
*/
export function bytesFromStringArray(arr: Uint8Array[]): number {
if (arr == null) {
Expand Down Expand Up @@ -717,8 +716,8 @@ export function locToIndex(
}

/**
* Computes the location (multidimensional index) in a tensor/multidimentional
* array for a given flat index.
* Computes the location (multidimensional index) in a
* tensor/multidimentional array for a given flat index.
*
* @param index Index in flat array.
* @param rank Rank of tensor.
Expand Down Expand Up @@ -749,43 +748,8 @@ export function isPromise(object: any): object is Promise<unknown> {
// We chose to not use 'obj instanceOf Promise' for two reasons:
// 1. It only reliably works for es6 Promise, not other Promise
// implementations.
// 2. It doesn't work with framework that uses zone.js. zone.js monkey patch
// the async calls, so it is possible the obj (patched) is comparing to a
// pre-patched Promise.
// 2. It doesn't work with framework that uses zone.js. zone.js monkey
// patch the async calls, so it is possible the obj (patched) is
// comparing to a pre-patched Promise.
return object && object.then && typeof object.then === 'function';
}

if (!('setTimeoutWPM' in window)) {
const messageName = 'setTimeoutWPM';
let fns: Function[] = [];
let handledMessageCount = 0;

// If the setTimeout nesting level is greater than 5 and timeout is less than
// 4ms, timeout will be clamped to 4ms, which hurts the perf. Interleaving
// window.postMessage and setTimeout will trick the browser and avoid the
// clamp.
const setTimeoutWPM = function(fn: Function, timeout: number) {
fns.push(fn);
setTimeout(() => {
window.postMessage({name: messageName, index: fns.length - 1}, '*');
}, timeout);
};

const handleMessage = function(event: MessageEvent) {
if (event.source == window && event.data.name == messageName) {
event.stopPropagation();
const fn = fns[event.data.index];
fn();
handledMessageCount++;
// console.log(`handledMessageCount=${handledMessageCount},
// fns.length=${fns.length}`);
if (handledMessageCount === fns.length) {
fns = [];
handledMessageCount = 0;
}
}
};

window.addEventListener('message', handleMessage, true);
(window as any).setTimeoutWPM = setTimeoutWPM;
}
55 changes: 0 additions & 55 deletions tfjs-core/src/util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -665,58 +665,3 @@ describe('util.decodeString', () => {
expect(util.isPromise(promise3)).toBeFalsy();
});
});

describe('setTimeout', () => {
// If we set a larger number here, an error will be reported as "'expect'
// was used when there was no current spec, this could be because an
// asynchronous test timed out".
const totalCount = 8;
const skipCount = 5;

it('setTimeout', () => {
let count = 0;
let startTime = performance.now();
let totalTime = 0;
setTimeout(_testSetTimeout, 0);

function _testSetTimeout() {
let endTime = performance.now();
count++;
if (count > skipCount) {
totalTime += endTime - startTime;
}
if (count === totalCount) {
let averageTime = totalTime / (totalCount - skipCount);
console.log(`averageTime of setTimeout is ${averageTime} ms`);
// We don't have expect here as in some browsers, like Chrome Canary,
// nesting level threshold is set to 100 instead of 5.
return;
}
startTime = performance.now();
setTimeout(_testSetTimeout, 0);
}
});

it('setTimeoutWPM', () => {
let count = 0;
let startTime = performance.now();
let totalTime = 0;
(window as any).setTimeoutWPM(_testSetTimeoutWPM, 0);

function _testSetTimeoutWPM() {
let endTime = performance.now();
count++;
if (count > skipCount) {
totalTime += endTime - startTime;
}
if (count === totalCount) {
let averageTime = totalTime / (totalCount - skipCount);
console.log(`averageTime of setTimeoutWPM is ${averageTime} ms`);
expect(averageTime).toBeLessThan(4);
return;
}
startTime = performance.now();
(window as any).setTimeoutWPM(_testSetTimeoutWPM, 0);
}
});
});

0 comments on commit b87bf71

Please sign in to comment.