From 7cddfa674a2993473debb1e3799bec488dc7fdd1 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Mon, 28 Aug 2023 14:36:24 -0700 Subject: [PATCH 1/5] Use describeWithFlags --- .../layers/nlp/multihead_attention_test.ts | 196 +++++++++--------- 1 file changed, 98 insertions(+), 98 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts index a389d1ab52c..420d141d798 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts @@ -25,10 +25,10 @@ import { TruncatedNormal } from '../../initializers'; import { input } from '../../exports'; import { Shape } from '../../keras_format/common'; import { MultiHeadAttention } from './multihead_attention'; -import { describeMathCPU, expectTensorsClose, expectTensorsNotClose } from '../../utils/test_utils'; +import { describeMathCPU, describeMathCPUAndGPU, expectTensorsClose, expectTensorsNotClose } from '../../utils/test_utils'; import { Embedding } from '../embeddings'; -describe('MultiHeadAttention', () => { +describeMathCPUAndGPU('MultiHeadAttention', () => { describe('Non Masked Attention', () => { interface NonMaskedAttentionArgs { @@ -188,101 +188,6 @@ describe('MultiHeadAttention', () => { expectTensorsNotClose(queryKernel, outputKernel, 1e-6); }); - describeMathCPU('High Dimensional Attention', () => { - interface HighDimAttentionArgs { - testcaseName: string; - qDims: Shape; - vDims: Shape; - maskDims: Shape; - attentionAxes: number[]; - } - /** - * Test with high dimensional inputs. - */ - function testHighDimAttention({ - testcaseName, qDims, vDims, maskDims, attentionAxes, - }: HighDimAttentionArgs) { - it(testcaseName, () => { - const testLayer = new MultiHeadAttention({ - numHeads: 2, keyDim: 2, attentionAxes, - }); - const batchSize = 3; - const hiddenSize = 8; - // Generate data for the input (non-mask) tensors. - const queryShape = [batchSize].concat(qDims).concat(hiddenSize); - const valueShape = [batchSize].concat(vDims).concat(hiddenSize); - const maskShape = [batchSize].concat(maskDims); - const query = randomUniform(queryShape, 0, 10); - const value = randomUniform(valueShape, 0, 10); - - // Invoke the data with a random set of mask data. This should mask at - // least one element. - const maskData = randomUniformInt(maskShape, 0, 2).asType('bool'); - - // Invoke the same data, but with a null mask (where no elements are - // masked). - const nullMaskData = ones(maskShape); - - // Because one data is masked and one is not, the outputs should not be - // the same. 
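This first commit swaps the plain `describe` for `describeMathCPUAndGPU` so every spec in the suite runs against both backends, while the High Dimensional Attention block being deleted here is relocated to a top-level `describeMathCPU` (it reappears further down in this same patch). For readers unfamiliar with the helper named in the commit title, here is a rough sketch of how a `describeWithFlags`-based wrapper could be wired; the import path and constraint objects are assumptions, not the exact tfjs-layers implementation:

```ts
// tslint:disable-next-line: no-imports-from-dist
import {describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';

// Register the same suite once per backend so each spec inside runs on both.
export function describeMathCPUAndGPU(name: string, tests: () => void): void {
  describeWithFlags(name, {predicate: env => env.backendName === 'cpu'}, tests);
  describeWithFlags(name, {predicate: env => env.backendName === 'webgl'}, tests);
}
```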
- - const outputWithMask = testLayer.call( - query, {value, attentionMask: maskData}); - const outputWithNullMask = testLayer.call( - query, {value, attentionMask: nullMaskData}); - - expectTensorsNotClose(outputWithMask, outputWithNullMask); - }); - } - const params: HighDimAttentionArgs[] = [ - { - testcaseName: '4d_inputs_1freebatch_mask2', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [4, 2], - attentionAxes: [2], - }, - { - testcaseName: '4d_inputs_1freebatch_mask3', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [3, 4, 2], - attentionAxes: [2], - }, - { - testcaseName: '4d_inputs_1freebatch_mask4', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [3, 2, 4, 2], - attentionAxes: [2], - }, - { - testcaseName: '4D_inputs_2D_attention', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [3, 4, 3, 2], - attentionAxes: [1, 2], - }, - { - testcaseName: '5D_inputs_2D_attention', - qDims: [5, 3, 4], - vDims: [5, 3, 2], - maskDims: [3, 4, 3, 2], - attentionAxes: [2, 3], - }, - { - testcaseName: '5D_inputs_2D_attention_fullmask', - qDims: [5, 3, 4], - vDims: [5, 3, 2], - maskDims: [5, 3, 4, 3, 2], - attentionAxes: [2, 3], - }, - ]; - for (const param of params) { - testHighDimAttention(param); - } - }); - it('dropout', () => { const testLayer = new MultiHeadAttention({ numHeads: 2, @@ -298,7 +203,7 @@ describe('MultiHeadAttention', () => { expectTensorsNotClose(trainOut, testOut); }); - describe('Casual Mask Value', () => { + fdescribe('Casual Mask Value', () => { /** * Test that the value and causal masks are taken into account. */ @@ -482,6 +387,101 @@ describe('MultiHeadAttention', () => { // TODO(pforderique): Test serialization. }); +describeMathCPU('High Dimensional Attention', () => { + interface HighDimAttentionArgs { + testcaseName: string; + qDims: Shape; + vDims: Shape; + maskDims: Shape; + attentionAxes: number[]; + } + /** + * Test with high dimensional inputs. + */ + function testHighDimAttention({ + testcaseName, qDims, vDims, maskDims, attentionAxes, + }: HighDimAttentionArgs) { + it(testcaseName, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 2, keyDim: 2, attentionAxes, + }); + const batchSize = 3; + const hiddenSize = 8; + // Generate data for the input (non-mask) tensors. + const queryShape = [batchSize].concat(qDims).concat(hiddenSize); + const valueShape = [batchSize].concat(vDims).concat(hiddenSize); + const maskShape = [batchSize].concat(maskDims); + const query = randomUniform(queryShape, 0, 10); + const value = randomUniform(valueShape, 0, 10); + + // Invoke the data with a random set of mask data. This should mask at + // least one element. + const maskData = randomUniformInt(maskShape, 0, 2).asType('bool'); + + // Invoke the same data, but with a null mask (where no elements are + // masked). + const nullMaskData = ones(maskShape); + + // Because one data is masked and one is not, the outputs should not be + // the same. 
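The relocated suite keeps the same detection strategy: run the layer once with a random boolean mask and once with an all-ones mask, then require the two outputs to differ. A minimal standalone sketch of that setup (shapes are illustrative):

```ts
import {ones, randomUniform, randomUniformInt} from '@tensorflow/tfjs-core';

const query = randomUniform([3, 4, 8], 0, 10);
const value = randomUniform([3, 2, 8], 0, 10);

// Random 0/1 mask. With 24 entries, at least one masked position is
// overwhelmingly likely, so the masked output should diverge from the
// null-masked one.
const maskData = randomUniformInt([3, 4, 2], 0, 2).asType('bool');
const nullMaskData = ones([3, 4, 2]); // masks nothing
```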
+ + const outputWithMask = testLayer.call( + query, {value, attentionMask: maskData}); + const outputWithNullMask = testLayer.call( + query, {value, attentionMask: nullMaskData}); + + expectTensorsNotClose(outputWithMask, outputWithNullMask); + }); + } + const params: HighDimAttentionArgs[] = [ + { + testcaseName: '4d_inputs_1freebatch_mask2', + qDims: [3, 4], + vDims: [3, 2], + maskDims: [4, 2], + attentionAxes: [2], + }, + { + testcaseName: '4d_inputs_1freebatch_mask3', + qDims: [3, 4], + vDims: [3, 2], + maskDims: [3, 4, 2], + attentionAxes: [2], + }, + { + testcaseName: '4d_inputs_1freebatch_mask4', + qDims: [3, 4], + vDims: [3, 2], + maskDims: [3, 2, 4, 2], + attentionAxes: [2], + }, + { + testcaseName: '4D_inputs_2D_attention', + qDims: [3, 4], + vDims: [3, 2], + maskDims: [3, 4, 3, 2], + attentionAxes: [1, 2], + }, + { + testcaseName: '5D_inputs_2D_attention', + qDims: [5, 3, 4], + vDims: [5, 3, 2], + maskDims: [3, 4, 3, 2], + attentionAxes: [2, 3], + }, + { + testcaseName: '5D_inputs_2D_attention_fullmask', + qDims: [5, 3, 4], + vDims: [5, 3, 2], + maskDims: [5, 3, 4, 3, 2], + attentionAxes: [2, 3], + }, + ]; + for (const param of params) { + testHighDimAttention(param); + } +}); + class SubclassAttention extends MultiHeadAttention { protected override buildAttention(qkvRank: number) {} From 0132ef56a27027e8064c7d0674e7703a9dd148b9 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Mon, 11 Sep 2023 14:09:16 -0700 Subject: [PATCH 2/5] Separate tests out of for loop --- .../layers/nlp/multihead_attention_test.ts | 87 ++++++++++--------- 1 file changed, 47 insertions(+), 40 deletions(-) diff --git a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts index 420d141d798..715b2f5e083 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts @@ -203,52 +203,59 @@ describeMathCPUAndGPU('MultiHeadAttention', () => { expectTensorsNotClose(trainOut, testOut); }); - fdescribe('Casual Mask Value', () => { + fdescribe('Causal Mask Value', () => { + let testLayer: MultiHeadAttention; + let maskedQuery: Tensor; + let maskedValue: Tensor; + let mask: Tensor; + + beforeEach(() => { + testLayer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); + const query = tensor2d([ + [1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0] + ]); + maskedQuery = new Embedding( + {inputDim: 4, outputDim: 8, maskZero: true}).apply(query) as Tensor; + const value = tensor2d([[5, 4, 0], [3, 0, 0], [2, 1, 1]]); + maskedValue = new Embedding( + {inputDim: 6, outputDim: 8, maskZero: true}).apply(value) as Tensor; + + mask = tensor([ + Array(3).fill([true, true, false]).concat( + Array(2).fill([false, false, false])), + Array(5).fill([true, false, false]), + [[true, true, true]].concat( + Array(4).fill([false, false, false])) + ]); + }); + /** * Test that the value and causal masks are taken into account. 
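The `mask` fixture built in `beforeEach` spells out, per batch, which (query step, value step) pairs survive padding: batch 0 has three real query tokens against two real values, batch 1 has five against one, and batch 2 has one against three. The same [B, T, S] tensor can be derived directly from the two padding masks; a sketch for cross-checking the fixture:

```ts
import {tensor2d} from '@tensorflow/tfjs-core';

// Padding masks implied by the embedded inputs above (true = real token).
const queryPad = tensor2d(
    [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1], [1, 0, 0, 0, 0]], [3, 5], 'bool');
const valuePad = tensor2d([[1, 1, 0], [1, 0, 0], [1, 1, 1]], [3, 3], 'bool');

// [B, T, 1] AND [B, 1, S] broadcasts to the same [B, T, S] fixture tensor.
const derivedMask = queryPad.expandDims(2).logicalAnd(valuePad.expandDims(1));
```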
*/ - function testValueMask(testcaseName: string, useCausalMask: boolean) { - it(testcaseName, () => { - const testLayer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); - const query = tensor2d([ - [1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0] - ]); - const maskedQuery = new Embedding( - {inputDim: 4, outputDim: 8, maskZero: true}).apply(query) as Tensor; - const value = tensor2d([[5, 4, 0], [3, 0, 0], [2, 1, 1]]); - const maskedValue = new Embedding( - {inputDim: 6, outputDim: 8, maskZero: true}).apply(value) as Tensor; - - const output = testLayer.call( - maskedQuery, {value: maskedValue, useCausalMask: true}); - - let mask = tensor([ - Array(3).fill([true, true, false]).concat( - Array(2).fill([false, false, false])), - Array(5).fill([true, false, false]), - [[true, true, true]].concat( - Array(4).fill([false, false, false])) - ]); - if (useCausalMask) { - mask = mask.logicalAnd(tensor([ - [[true, false, false], [true, true, false]].concat( - [[true, true, true], [true, true, true], [true, true, true]]) - ])); - } + fit('causal', () => { + const output = testLayer.call( + maskedQuery, {value: maskedValue, useCausalMask: true}); - const outputWithManualMask = testLayer.call( - maskedQuery, {value: maskedValue, attentionMask: mask}); + const outputWithManualMask = testLayer.call( + maskedQuery, {value: maskedValue, attentionMask: mask}); - expectTensorsClose(output, outputWithManualMask); - }); - } + expectTensorsClose(output, outputWithManualMask); + }); - const params: Array<[string, boolean]> = [ - ['casual', true], ['not_casual', false] - ]; - for (const [testName, useMask] of params) { - testValueMask(testName, useMask); - } + it('not_causal', () => { + const output = testLayer.call( + maskedQuery, {value: maskedValue, useCausalMask: true}); + + mask = mask.logicalAnd(tensor([ + [[true, false, false], [true, true, false]].concat( + [[true, true, true], [true, true, true], [true, true, true]]) + ])); + + const outputWithManualMask = testLayer.call( + maskedQuery, {value: maskedValue, attentionMask: mask}); + + expectTensorsClose(output, outputWithManualMask); + }); }); describe('Compute Output Shape', () => { From 1b89d25aaf6449e5f996ca26b49b34923757800d Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Mon, 11 Sep 2023 23:11:10 -0700 Subject: [PATCH 3/5] Make MHA use masks from query and value tensors --- tfjs-core/src/tensor.ts | 5 ++++ tfjs-layers/src/engine/topology.ts | 26 ++++++++++++++++++- .../src/layers/nlp/multihead_attention.ts | 14 +++++++--- .../layers/nlp/multihead_attention_test.ts | 21 ++++++++------- 4 files changed, 52 insertions(+), 14 deletions(-) diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 59cc89c5bd8..02cee9fe649 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -275,6 +275,8 @@ export class Tensor implements TensorInfo { kept = false; /** The id of the scope this tensor is being tracked in. */ scopeId: number; + /** The keras mask that some keras layers attach to the tensor */ + keras_mask?: Tensor; /** * Number of elements to skip in each dimension when indexing. 
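The new optional `keras_mask` field lets a layer hang its output mask directly on the tensor instead of threading it through every call signature. A sketch of the intended shape of the data, using the snake_case name this commit introduces (a later commit in the series renames it to `kerasMask`); the Embedding stand-in below is hypothetical:

```ts
import {notEqual, ones, tensor2d, Tensor} from '@tensorflow/tfjs-core';

const tokenIds = tensor2d([[1, 2, 0], [3, 0, 0]]); // 0 marks padding
const embedded: Tensor = ones([2, 3, 8]);  // stand-in for an Embedding output
// A maskZero Embedding would attach the padding mask roughly like this:
embedded.keras_mask = notEqual(tokenIds, 0);       // boolean, shape [2, 3]
```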
See @@ -442,6 +444,9 @@ export class Tensor implements TensorInfo { if (this.isDisposed) { return; } + if (this.keras_mask) { + this.keras_mask.dispose(); + } trackerFn().disposeTensor(this); this.isDisposedInternal = true; } diff --git a/tfjs-layers/src/engine/topology.ts b/tfjs-layers/src/engine/topology.ts index 37b64663750..0de81e15171 100644 --- a/tfjs-layers/src/engine/topology.ts +++ b/tfjs-layers/src/engine/topology.ts @@ -1018,7 +1018,31 @@ export abstract class Layer extends serialization.Serializable { // Actually call the layer, collecting output(s), mask(s), and shape(s). if (noneAreSymbolic) { let output = this.call(inputs as Tensor | Tensor[], kwargs); - // TODO(michaelterry): Compute the outputMask + + // Apply masks to the output tensors if the layer supports it. + if (this.supportsMasking) { + // TODO(mattsoulanille): pass the input tensors' masks to computeMask + const outputMask = this.computeMask(inputs as Tensor | Tensor[]); + if (output instanceof Array && outputMask instanceof Array) { + if (output.length !== outputMask.length) { + throw new Error(`${this.name} output ${output.length} tensors ` + + `but ${outputMask.length} masks for those tensors`); + } + for (let i = 0; i < output.length; i++) { + output[i].keras_mask = outputMask[i]; + } + } else if (outputMask instanceof Array) { + throw new Error(`{this.name} output a single tensor ` + + `but ${outputMask.length} masks`); + } else if (output instanceof Array) { + for (const out of output) { + out.keras_mask = outputMask.clone(); + } + outputMask.dispose(); // Only keep the clones to avoid leaking + } else { + output.keras_mask = outputMask; + } + } // If the layer returns tensors from its inputs, unmodified, // we copy them to avoid loss of tensor metadata. diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index b89d472f27d..e86cb565bb0 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -20,7 +20,7 @@ */ /* Original source: keras/layers/attention/multi_head_attention.py */ -import { Tensor, einsum, linalg, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core'; +import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core'; // tslint:disable-next-line: no-imports-from-dist import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base'; @@ -813,12 +813,20 @@ export class MultiHeadAttention extends Layer { return tidy(() => { let autoMask: Tensor; + const queryMask = query.keras_mask; + const valueMask = value.keras_mask; + if (queryMask != null) { + autoMask = queryMask.expandDims(2); // Shape is [B, T, 1] + } + if (valueMask != null) { + const mask = valueMask.expandDims(1); // Shape is [B, 1, S] + autoMask = autoMask ? logicalAnd(autoMask, mask) : mask; + } if (useCausalMask) { // the shape of the causal mask is [1, T, S] const mask = this.computeCausalMask(query, value); - autoMask = mask; + autoMask = autoMask ? logicalAnd(autoMask, mask) : mask; } - if (autoMask != null) { // Merge attentionMask & automatic mask, to shape [B, T, S] attentionMask = attentionMask ? 
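The merge above works purely by broadcasting: the query mask is expanded to [B, T, 1], the value mask to [B, 1, S], and the causal mask arrives as [1, T, S], so the chained `logicalAnd` calls yield a final [B, T, S] attention mask. The causal term produced by `computeCausalMask` is the usual lower-triangular band; a sketch of that construction, assuming T = S = 3:

```ts
import {linalg, ones} from '@tensorflow/tfjs-core';

// Keep the whole lower triangle (numLower = -1) and no upper diagonals.
const causal = linalg.bandPart(ones([1, 3, 3]), -1, 0).asType('bool');
// causal[0] ==
// [[true, false, false],
//  [true, true,  false],
//  [true, true,  true ]]
```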
diff --git a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts index 715b2f5e083..1cc5d77f0e5 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts @@ -203,7 +203,7 @@ describeMathCPUAndGPU('MultiHeadAttention', () => { expectTensorsNotClose(trainOut, testOut); }); - fdescribe('Causal Mask Value', () => { + describe('Causal Mask Value', () => { let testLayer: MultiHeadAttention; let maskedQuery: Tensor; let maskedValue: Tensor; @@ -214,8 +214,9 @@ describeMathCPUAndGPU('MultiHeadAttention', () => { const query = tensor2d([ [1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0] ]); - maskedQuery = new Embedding( - {inputDim: 4, outputDim: 8, maskZero: true}).apply(query) as Tensor; + const maskedQueryLayer = new Embedding( + {inputDim: 4, outputDim: 8, maskZero: true}); + maskedQuery = maskedQueryLayer.apply(query) as Tensor; const value = tensor2d([[5, 4, 0], [3, 0, 0], [2, 1, 1]]); maskedValue = new Embedding( {inputDim: 6, outputDim: 8, maskZero: true}).apply(value) as Tensor; @@ -232,10 +233,15 @@ describeMathCPUAndGPU('MultiHeadAttention', () => { /** * Test that the value and causal masks are taken into account. */ - fit('causal', () => { + it('causal', () => { const output = testLayer.call( maskedQuery, {value: maskedValue, useCausalMask: true}); + mask = mask.logicalAnd(tensor([ + [[true, false, false], [true, true, false]].concat( + [[true, true, true], [true, true, true], [true, true, true]]) + ])); + const outputWithManualMask = testLayer.call( maskedQuery, {value: maskedValue, attentionMask: mask}); @@ -244,12 +250,7 @@ describeMathCPUAndGPU('MultiHeadAttention', () => { it('not_causal', () => { const output = testLayer.call( - maskedQuery, {value: maskedValue, useCausalMask: true}); - - mask = mask.logicalAnd(tensor([ - [[true, false, false], [true, true, false]].concat( - [[true, true, true], [true, true, true], [true, true, true]]) - ])); + maskedQuery, {value: maskedValue, useCausalMask: false}); const outputWithManualMask = testLayer.call( maskedQuery, {value: maskedValue, attentionMask: mask}); From bdf6ed1cd0f211f51dc3761e1a629c050ebc1df2 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 12 Sep 2023 00:03:27 -0700 Subject: [PATCH 4/5] Fix lint --- tfjs-core/src/tensor.ts | 6 +++--- tfjs-layers/src/engine/topology.ts | 6 +++--- tfjs-layers/src/layers/nlp/multihead_attention.ts | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 02cee9fe649..137d3b48705 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -276,7 +276,7 @@ export class Tensor implements TensorInfo { /** The id of the scope this tensor is being tracked in. */ scopeId: number; /** The keras mask that some keras layers attach to the tensor */ - keras_mask?: Tensor; + kerasMask?: Tensor; /** * Number of elements to skip in each dimension when indexing. 
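This commit is a mechanical lint rename (`keras_mask` to `kerasMask`); behavior is unchanged. The ownership rule from the previous commit still applies: disposing a tensor also disposes any mask attached to it. A small accounting sketch, assuming no other live tensors:

```ts
import * as tf from '@tensorflow/tfjs-core';

const out = tf.ones([2, 3]);
out.kerasMask = tf.ones([2], 'bool');
console.log(tf.memory().numTensors); // 2
out.dispose();                       // also disposes out.kerasMask
console.log(tf.memory().numTensors); // 0
```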
See @@ -444,8 +444,8 @@ export class Tensor implements TensorInfo { if (this.isDisposed) { return; } - if (this.keras_mask) { - this.keras_mask.dispose(); + if (this.kerasMask) { + this.kerasMask.dispose(); } trackerFn().disposeTensor(this); this.isDisposedInternal = true; diff --git a/tfjs-layers/src/engine/topology.ts b/tfjs-layers/src/engine/topology.ts index 0de81e15171..85ba63ccd66 100644 --- a/tfjs-layers/src/engine/topology.ts +++ b/tfjs-layers/src/engine/topology.ts @@ -1029,18 +1029,18 @@ export abstract class Layer extends serialization.Serializable { + `but ${outputMask.length} masks for those tensors`); } for (let i = 0; i < output.length; i++) { - output[i].keras_mask = outputMask[i]; + output[i].kerasMask = outputMask[i]; } } else if (outputMask instanceof Array) { throw new Error(`{this.name} output a single tensor ` + `but ${outputMask.length} masks`); } else if (output instanceof Array) { for (const out of output) { - out.keras_mask = outputMask.clone(); + out.kerasMask = outputMask.clone(); } outputMask.dispose(); // Only keep the clones to avoid leaking } else { - output.keras_mask = outputMask; + output.kerasMask = outputMask; } } diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts index e86cb565bb0..134ff577ff8 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts @@ -813,8 +813,8 @@ export class MultiHeadAttention extends Layer { return tidy(() => { let autoMask: Tensor; - const queryMask = query.keras_mask; - const valueMask = value.keras_mask; + const queryMask = query.kerasMask; + const valueMask = value.kerasMask; if (queryMask != null) { autoMask = queryMask.expandDims(2); // Shape is [B, T, 1] } From 682227871e28096d1701915629f0f4095f9100bb Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 12 Sep 2023 14:52:21 -0700 Subject: [PATCH 5/5] Refactor mask computation into a separate function --- tfjs-layers/src/base_callbacks.ts | 3 +- tfjs-layers/src/engine/topology.ts | 104 +++++++++++++++---------- tfjs-layers/src/utils/generic_utils.ts | 2 +- 3 files changed, 65 insertions(+), 44 deletions(-) diff --git a/tfjs-layers/src/base_callbacks.ts b/tfjs-layers/src/base_callbacks.ts index 50eea0c8a56..e79cbf11eb2 100644 --- a/tfjs-layers/src/base_callbacks.ts +++ b/tfjs-layers/src/base_callbacks.ts @@ -488,7 +488,8 @@ export function standardizeCallbacks( } // Convert custom callback configs to custom callback objects. 
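The reflowed `toList` call below goes hand in hand with `toList` gaining a generic signature in this same commit (see the `generic_utils.ts` hunk at the end of the patch). A simplified standalone version of the helper, for illustration only:

```ts
function toList<T>(x: T | T[]): T[] {
  if (Array.isArray(x)) {
    return x;
  }
  return [x];
}

const a = toList(3);      // inferred as number[]: [3]
const b = toList([1, 2]); // already a list, returned as-is: [1, 2]
```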
const callbackConfigs = - generic_utils.toList(callbacks) as CustomCallbackArgs[]; + generic_utils.toList( + callbacks) as CustomCallbackArgs[]; return callbackConfigs.map( callbackConfig => new CustomCallback(callbackConfig, yieldEvery)); } diff --git a/tfjs-layers/src/engine/topology.ts b/tfjs-layers/src/engine/topology.ts index 85ba63ccd66..ffa581560ba 100644 --- a/tfjs-layers/src/engine/topology.ts +++ b/tfjs-layers/src/engine/topology.ts @@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable { */ protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor| SymbolicTensor[]): void { - inputs = generic_utils.toList(inputs); + const inputsList = generic_utils.toList(inputs); if (this.inputSpec == null || this.inputSpec.length === 0) { return; } const inputSpec = generic_utils.toList(this.inputSpec); - if (inputs.length !== inputSpec.length) { + if (inputsList.length !== inputSpec.length) { throw new ValueError( `Layer ${this.name} expects ${inputSpec.length} inputs, ` + - `but it received ${inputs.length} input tensors. ` + + `but it received ${inputsList.length} input tensors. ` + `Input received: ${inputs}`); } - for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) { - const x = inputs[inputIndex]; + for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) { + const x = inputsList[inputIndex]; const spec: InputSpec = inputSpec[inputIndex]; if (spec == null) { continue; @@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable { // Ensure inputs are all the same type. const inputsList = generic_utils.toList(inputs); - let allAreSymbolic = true; - for (const input of inputsList) { - if (!(input instanceof SymbolicTensor)) { - allAreSymbolic = false; - break; - } - } - let noneAreSymbolic = true; - for (const input of inputsList) { - if (input instanceof SymbolicTensor) { - noneAreSymbolic = false; - break; - } - } + const allAreSymbolic = checkAllSymbolic(inputs); + const noneAreSymbolic = checkNoneSymbolic(inputs); if (allAreSymbolic === noneAreSymbolic) { throw new ValueError( @@ -1017,31 +1005,12 @@ export abstract class Layer extends serialization.Serializable { // Actually call the layer, collecting output(s), mask(s), and shape(s). if (noneAreSymbolic) { - let output = this.call(inputs as Tensor | Tensor[], kwargs); + let output = this.call(inputs, kwargs); // Apply masks to the output tensors if the layer supports it. 
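`checkAllSymbolic` and `checkNoneSymbolic` (defined at the bottom of this diff) are user-defined type guards, which is what lets `this.call(inputs, kwargs)` above drop its `as Tensor | Tensor[]` cast: inside the `if (noneAreSymbolic)` branch the compiler has already narrowed `inputs`. A minimal illustration of the pattern with stand-in types:

```ts
class SymbolicTensor { id = 0; }
class ConcreteTensor { shape: number[] = []; }
type MaybeSymbolic = SymbolicTensor | ConcreteTensor;

// The `xs is ConcreteTensor[]` return type makes the narrowing visible
// to the compiler at every call site.
function noneSymbolic(xs: MaybeSymbolic[]): xs is ConcreteTensor[] {
  return xs.every(x => !(x instanceof SymbolicTensor));
}

function run(inputs: MaybeSymbolic[]) {
  if (noneSymbolic(inputs)) {
    // inputs is ConcreteTensor[] here; no cast needed.
    inputs.forEach(t => console.log(t.shape));
  }
}
```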
if (this.supportsMasking) { // TODO(mattsoulanille): pass the input tensors' masks to computeMask - const outputMask = this.computeMask(inputs as Tensor | Tensor[]); - if (output instanceof Array && outputMask instanceof Array) { - if (output.length !== outputMask.length) { - throw new Error(`${this.name} output ${output.length} tensors ` - + `but ${outputMask.length} masks for those tensors`); - } - for (let i = 0; i < output.length; i++) { - output[i].kerasMask = outputMask[i]; - } - } else if (outputMask instanceof Array) { - throw new Error(`{this.name} output a single tensor ` - + `but ${outputMask.length} masks`); - } else if (output instanceof Array) { - for (const out of output) { - out.kerasMask = outputMask.clone(); - } - outputMask.dispose(); // Only keep the clones to avoid leaking - } else { - output.kerasMask = outputMask; - } + this.setMaskMetadata(inputs, output); } // If the layer returns tensors from its inputs, unmodified, @@ -1097,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable { If the input tensor(s) had no previous history, this does nothing. */ - this.addInboundNode( - inputs as SymbolicTensor | SymbolicTensor[], output, null, null, + this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs); this._refCount++; @@ -1419,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable { return mask; } + private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[], + previousMask?: Tensor|Tensor[]): void { + if (!this.supportsMasking) { + return; + } + + const outputMasks = this.computeMask(inputs, previousMask); + if (outputs instanceof Array && outputMasks instanceof Array) { + if (outputs.length !== outputMasks.length) { + throw new Error(`${this.name} outputs ${outputs.length} tensors ` + + `but ${outputMasks.length} masks for those tensors`); + } + for (let i = 0; i < outputs.length; i++) { + outputs[i].kerasMask = outputMasks[i]; + } + } else if (outputMasks instanceof Array) { + throw new Error(`{this.name} outputs a single tensor ` + + `but ${outputMasks.length} masks`); + } else if (outputs instanceof Array) { + throw new Error(`{this.name} outputs ${outputs.length} tensors ` + + `but only one mask`); + } else { + outputs.kerasMask = outputMasks; + } + } + /** * Internal method to create an inbound node for the layer. * @@ -1666,3 +1660,29 @@ export function getSourceInputs( } } } + +type MaybeSymbolic = SymbolicTensor | Tensor; + +function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[] + ): tensors is SymbolicTensor | SymbolicTensor[] { + let allAreSymbolic = true; + for (const tensor of generic_utils.toList(tensors)) { + if (!(tensor instanceof SymbolicTensor)) { + allAreSymbolic = false; + break; + } + } + return allAreSymbolic; +} + +function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[] + ): tensors is Tensor | Tensor[] { + let noneAreSymbolic = true; + for (const tensor of generic_utils.toList(tensors)) { + if (tensor instanceof SymbolicTensor) { + noneAreSymbolic = false; + break; + } + } + return noneAreSymbolic; +} diff --git a/tfjs-layers/src/utils/generic_utils.ts b/tfjs-layers/src/utils/generic_utils.ts index 2d06342c8e1..c128097f5ed 100644 --- a/tfjs-layers/src/utils/generic_utils.ts +++ b/tfjs-layers/src/utils/generic_utils.ts @@ -76,7 +76,7 @@ export function singletonOrArray(xs: T[]): T|T[] { * @param x target object to be normalized. 
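Taken together, the series means the causal-mask tests from patch 2 no longer need manual mask plumbing: an Embedding with `maskZero` attaches `kerasMask` to its output via `Layer.apply`, and `MultiHeadAttention` folds the query and value masks into its attention mask automatically. A sketch of that end-to-end flow; the relative import paths are assumptions:

```ts
import {Tensor, tensor2d} from '@tensorflow/tfjs-core';
import {Embedding} from './layers/embeddings';
import {MultiHeadAttention} from './layers/nlp/multihead_attention';

const embed = new Embedding({inputDim: 4, outputDim: 8, maskZero: true});
const query = embed.apply(tensor2d([[1, 2, 0]])) as Tensor; // 0 = padding
// query.kerasMask is now [[true, true, false]].

const mha = new MultiHeadAttention({numHeads: 2, keyDim: 2});
// No explicit attentionMask needed; the padding and causal masks are merged.
const out = mha.call(query, {value: query, useCausalMask: true});
```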
 */
 // tslint:disable-next-line:no-any
-export function toList(x: any): any[] {
+export function toList<T>(x: T|T[]): T[] {
   if (Array.isArray(x)) {
     return x;
   }