Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGL backend] Add max 1D texture dimension flag #6808

Merged
merged 16 commits into from
Sep 1, 2022
Merged
22 changes: 22 additions & 0 deletions tfjs-backend-webgl/src/flags_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,25 @@ ENV.registerFlag('WEBGL_EXP_CONV', () => false);
* software WebGL will be used.
*/
ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST'));

/**
* For narrow texture (physical height or physical width is 1), if the length of
* any texture edges exceed the threshold, the texture will be reshaped to be
* more squarish.
*
* This flag is used to help some GPUs that could not provide correct
* interpolations for long skinny triangles. We found Mali GPU probably has this
* problem: https://github.com/tensorflow/tfjs/issues/6775.
*/
ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity);

/**
* If the flag is set to true, the max size of the narrow texture will be auto
* computed and it will be considerred as a threshold to reshape the narrow
* texture to be more squarish.
*
* This flag is used to help some GPUs that could not provide correct
* interpolations for long skinny triangles. We found Mali GPU probably has this
* problem: https://github.com/tensorflow/tfjs/issues/6775.
*/
ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false);
35 changes: 35 additions & 0 deletions tfjs-backend-webgl/src/flags_webgl_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,38 @@ describe('WEBGL_EXP_CONV', () => {
expect(tf.env().getBool('WEBGL_EXP_CONV')).toBe(false);
});
});

const MAX_SIZE_FOR_NARROR_TEX_FLAG = 'WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE';

describeWithFlags(MAX_SIZE_FOR_NARROR_TEX_FLAG, WEBGL_ENVS, () => {
beforeEach(() => tf.env().reset());
afterAll(() => tf.env().reset());

it(`returns correct value when ${MAX_SIZE_FOR_NARROR_TEX_FLAG} is set`,
() => {
tf.env().set(MAX_SIZE_FOR_NARROR_TEX_FLAG, 2048);
expect(tf.env().getNumber(MAX_SIZE_FOR_NARROR_TEX_FLAG)).toBe(2048);
});

it(`returns default when ${MAX_SIZE_FOR_NARROR_TEX_FLAG} is not set`, () => {
expect(tf.env().getNumber(MAX_SIZE_FOR_NARROR_TEX_FLAG)).toBe(Infinity);
});
});

const AUTO_SQUARIFY_NARROW_TEX_FLAG =
'WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE';

describeWithFlags(AUTO_SQUARIFY_NARROW_TEX_FLAG, WEBGL_ENVS, () => {
beforeEach(() => tf.env().reset());
afterAll(() => tf.env().reset());

it(`returns correct value when ${AUTO_SQUARIFY_NARROW_TEX_FLAG} is set`,
() => {
tf.env().set(AUTO_SQUARIFY_NARROW_TEX_FLAG, true);
expect(tf.env().getBool(AUTO_SQUARIFY_NARROW_TEX_FLAG)).toBe(true);
});

it(`returns default when ${AUTO_SQUARIFY_NARROW_TEX_FLAG} is not set`, () => {
expect(tf.env().getBool(AUTO_SQUARIFY_NARROW_TEX_FLAG)).toBe(false);
});
});
40 changes: 31 additions & 9 deletions tfjs-backend-webgl/src/webgl_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,16 @@ export function getShapeAs3D(shape: number[]): [number, number, number] {
export function getTextureShapeFromLogicalShape(
logShape: number[], isPacked = false): [number, number] {
let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
let maxSizeForNarrowTex =
env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
if (maxSizeForNarrowTex === Infinity &&
env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
maxSizeForNarrowTex = maxTexSize / 2;
}

if (isPacked) {
maxTexSize = maxTexSize * 2;
maxSizeForNarrowTex = maxSizeForNarrowTex * 2;

// This logic ensures we accurately count the number of packed texels needed
// to accommodate the tensor. We can only pack values in the same texel if
Expand All @@ -395,30 +403,40 @@ export function getTextureShapeFromLogicalShape(
}

let size = util.sizeFromShape(logShape);
let textureShape: [number, number] = null;
if (logShape.length <= 1 && size <= maxTexSize) {
return [1, size];
textureShape = [1, size];
} else if (
logShape.length === 2 && logShape[0] <= maxTexSize &&
logShape[1] <= maxTexSize) {
return logShape as [number, number];
textureShape = logShape as [number, number];
} else if (
logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
logShape[2] <= maxTexSize) {
return [logShape[0] * logShape[1], logShape[2]];
textureShape = [logShape[0] * logShape[1], logShape[2]];
} else if (
logShape.length === 3 && logShape[0] <= maxTexSize &&
logShape[1] * logShape[2] <= maxTexSize) {
return [logShape[0], logShape[1] * logShape[2]];
textureShape = [logShape[0], logShape[1] * logShape[2]];
} else if (
logShape.length === 4 &&
logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
logShape[3] <= maxTexSize) {
return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]];
} else if (
logShape.length === 4 && logShape[0] <= maxTexSize &&
logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
} else {
textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]];
}

// true if one edge length is 1 (1 or 2, if packed), while another edge
// length exceeds maxSizeForNarrowTex.
const isLongNarrowTex = textureShape != null &&
Math.max(...textureShape) > maxSizeForNarrowTex &&
Math.min(...textureShape) <= (isPacked ? 2 : 1) &&
Math.min(...textureShape) > 0;

if (textureShape == null || isLongNarrowTex) {
if (isPacked) {
// For packed textures size equals the number of channels required to
// accommodate the texture data. However in order to squarify such that
Expand All @@ -432,10 +450,14 @@ export function getTextureShapeFromLogicalShape(
[rows, cols] = getRowsCols(logShape);
}
size = batchDim * (rows / 2) * (cols / 2);
return util.sizeToSquarishShape(size).map(d => d * 2) as [number, number];
textureShape =
util.sizeToSquarishShape(size).map(d => d * 2) as [number, number];
} else {
textureShape = util.sizeToSquarishShape(size);
}
return util.sizeToSquarishShape(size);
}

return textureShape;
}

function isEven(n: number): boolean {
Expand Down
34 changes: 34 additions & 0 deletions tfjs-backend-webgl/src/webgl_util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,40 @@ describeWithFlags('getTextureShapeFromLogicalShape packed', WEBGL_ENVS, () => {
tf.env().set('WEBGL_MAX_TEXTURE_SIZE', max);
expect(texShape).toEqual([6, 4]);
});

it('squarified long narrow texture shapes', () => {
const isPacked = true;
const max = tf.env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');

tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', 5);
const logicalShape = [1, 16];
const texShape =
webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked);

tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', max);
expect(texShape).toEqual([6, 6]);
});

it('auto squarified long narrow texture shapes', () => {
const isPacked = true;
const max = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
const maxForNarrowTex =
tf.env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
const autoSquarify =
tf.env().getNumber('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE');

tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 6);
tf.env().set('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', true);
tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', Infinity);
const logicalShape = [1, 16];
const texShape =
webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked);

tf.env().set('WEBGL_MAX_TEXTURE_SIZE', max);
tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', maxForNarrowTex);
tf.env().set('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', autoSquarify);
expect(texShape).toEqual([6, 6]);
});
});

describeWithFlags('isReshapeFree', WEBGL_ENVS, () => {
Expand Down