diff --git a/tfjs-backend-webgl/src/flags_webgl.ts b/tfjs-backend-webgl/src/flags_webgl.ts index dbfbcf95c72..6846e1d64a8 100644 --- a/tfjs-backend-webgl/src/flags_webgl.ts +++ b/tfjs-backend-webgl/src/flags_webgl.ts @@ -246,3 +246,25 @@ ENV.registerFlag('WEBGL_EXP_CONV', () => false); * software WebGL will be used. */ ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST')); + +/** + * For narrow texture (physical height or physical width is 1), if the length of + * any texture edges exceed the threshold, the texture will be reshaped to be + * more squarish. + * + * This flag is used to help some GPUs that could not provide correct + * interpolations for long skinny triangles. We found Mali GPU probably has this + * problem: https://github.com/tensorflow/tfjs/issues/6775. + */ +ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity); + +/** + * If the flag is set to true, the max size of the narrow texture will be auto + * computed and it will be considerred as a threshold to reshape the narrow + * texture to be more squarish. + * + * This flag is used to help some GPUs that could not provide correct + * interpolations for long skinny triangles. We found Mali GPU probably has this + * problem: https://github.com/tensorflow/tfjs/issues/6775. + */ +ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false); diff --git a/tfjs-backend-webgl/src/flags_webgl_test.ts b/tfjs-backend-webgl/src/flags_webgl_test.ts index 3ce0dc4b0ca..a8568863efc 100644 --- a/tfjs-backend-webgl/src/flags_webgl_test.ts +++ b/tfjs-backend-webgl/src/flags_webgl_test.ts @@ -444,3 +444,38 @@ describe('WEBGL_EXP_CONV', () => { expect(tf.env().getBool('WEBGL_EXP_CONV')).toBe(false); }); }); + +const MAX_SIZE_FOR_NARROR_TEX_FLAG = 'WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE'; + +describeWithFlags(MAX_SIZE_FOR_NARROR_TEX_FLAG, WEBGL_ENVS, () => { + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); + + it(`returns correct value when ${MAX_SIZE_FOR_NARROR_TEX_FLAG} is set`, + () => { + tf.env().set(MAX_SIZE_FOR_NARROR_TEX_FLAG, 2048); + expect(tf.env().getNumber(MAX_SIZE_FOR_NARROR_TEX_FLAG)).toBe(2048); + }); + + it(`returns default when ${MAX_SIZE_FOR_NARROR_TEX_FLAG} is not set`, () => { + expect(tf.env().getNumber(MAX_SIZE_FOR_NARROR_TEX_FLAG)).toBe(Infinity); + }); +}); + +const AUTO_SQUARIFY_NARROW_TEX_FLAG = + 'WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE'; + +describeWithFlags(AUTO_SQUARIFY_NARROW_TEX_FLAG, WEBGL_ENVS, () => { + beforeEach(() => tf.env().reset()); + afterAll(() => tf.env().reset()); + + it(`returns correct value when ${AUTO_SQUARIFY_NARROW_TEX_FLAG} is set`, + () => { + tf.env().set(AUTO_SQUARIFY_NARROW_TEX_FLAG, true); + expect(tf.env().getBool(AUTO_SQUARIFY_NARROW_TEX_FLAG)).toBe(true); + }); + + it(`returns default when ${AUTO_SQUARIFY_NARROW_TEX_FLAG} is not set`, () => { + expect(tf.env().getBool(AUTO_SQUARIFY_NARROW_TEX_FLAG)).toBe(false); + }); +}); diff --git a/tfjs-backend-webgl/src/webgl_util.ts b/tfjs-backend-webgl/src/webgl_util.ts index 52c8bd4d8df..31df6777eeb 100644 --- a/tfjs-backend-webgl/src/webgl_util.ts +++ b/tfjs-backend-webgl/src/webgl_util.ts @@ -368,8 +368,16 @@ export function getShapeAs3D(shape: number[]): [number, number, number] { export function getTextureShapeFromLogicalShape( logShape: number[], isPacked = false): [number, number] { let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + let maxSizeForNarrowTex = + env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE'); + if (maxSizeForNarrowTex === Infinity && + env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) { + maxSizeForNarrowTex = maxTexSize / 2; + } + if (isPacked) { maxTexSize = maxTexSize * 2; + maxSizeForNarrowTex = maxSizeForNarrowTex * 2; // This logic ensures we accurately count the number of packed texels needed // to accommodate the tensor. We can only pack values in the same texel if @@ -395,30 +403,40 @@ export function getTextureShapeFromLogicalShape( } let size = util.sizeFromShape(logShape); + let textureShape: [number, number] = null; if (logShape.length <= 1 && size <= maxTexSize) { - return [1, size]; + textureShape = [1, size]; } else if ( logShape.length === 2 && logShape[0] <= maxTexSize && logShape[1] <= maxTexSize) { - return logShape as [number, number]; + textureShape = logShape as [number, number]; } else if ( logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize && logShape[2] <= maxTexSize) { - return [logShape[0] * logShape[1], logShape[2]]; + textureShape = [logShape[0] * logShape[1], logShape[2]]; } else if ( logShape.length === 3 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] <= maxTexSize) { - return [logShape[0], logShape[1] * logShape[2]]; + textureShape = [logShape[0], logShape[1] * logShape[2]]; } else if ( logShape.length === 4 && logShape[0] * logShape[1] * logShape[2] <= maxTexSize && logShape[3] <= maxTexSize) { - return [logShape[0] * logShape[1] * logShape[2], logShape[3]]; + textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]]; } else if ( logShape.length === 4 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] * logShape[3] <= maxTexSize) { - return [logShape[0], logShape[1] * logShape[2] * logShape[3]]; - } else { + textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]]; + } + + // true if one edge length is 1 (1 or 2, if packed), while another edge + // length exceeds maxSizeForNarrowTex. + const isLongNarrowTex = textureShape != null && + Math.max(...textureShape) > maxSizeForNarrowTex && + Math.min(...textureShape) <= (isPacked ? 2 : 1) && + Math.min(...textureShape) > 0; + + if (textureShape == null || isLongNarrowTex) { if (isPacked) { // For packed textures size equals the number of channels required to // accommodate the texture data. However in order to squarify such that @@ -432,10 +450,14 @@ export function getTextureShapeFromLogicalShape( [rows, cols] = getRowsCols(logShape); } size = batchDim * (rows / 2) * (cols / 2); - return util.sizeToSquarishShape(size).map(d => d * 2) as [number, number]; + textureShape = + util.sizeToSquarishShape(size).map(d => d * 2) as [number, number]; + } else { + textureShape = util.sizeToSquarishShape(size); } - return util.sizeToSquarishShape(size); } + + return textureShape; } function isEven(n: number): boolean { diff --git a/tfjs-backend-webgl/src/webgl_util_test.ts b/tfjs-backend-webgl/src/webgl_util_test.ts index 1377a4b25f0..d57d0b181f6 100644 --- a/tfjs-backend-webgl/src/webgl_util_test.ts +++ b/tfjs-backend-webgl/src/webgl_util_test.ts @@ -114,6 +114,40 @@ describeWithFlags('getTextureShapeFromLogicalShape packed', WEBGL_ENVS, () => { tf.env().set('WEBGL_MAX_TEXTURE_SIZE', max); expect(texShape).toEqual([6, 4]); }); + + it('squarified long narrow texture shapes', () => { + const isPacked = true; + const max = tf.env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE'); + + tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', 5); + const logicalShape = [1, 16]; + const texShape = + webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked); + + tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', max); + expect(texShape).toEqual([6, 6]); + }); + + it('auto squarified long narrow texture shapes', () => { + const isPacked = true; + const max = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + const maxForNarrowTex = + tf.env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE'); + const autoSquarify = + tf.env().getNumber('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE'); + + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', 6); + tf.env().set('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', true); + tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', Infinity); + const logicalShape = [1, 16]; + const texShape = + webgl_util.getTextureShapeFromLogicalShape(logicalShape, isPacked); + + tf.env().set('WEBGL_MAX_TEXTURE_SIZE', max); + tf.env().set('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', maxForNarrowTex); + tf.env().set('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', autoSquarify); + expect(texShape).toEqual([6, 6]); + }); }); describeWithFlags('isReshapeFree', WEBGL_ENVS, () => {