diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 59cc89c5bd8..137d3b48705 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -275,6 +275,8 @@ export class Tensor implements TensorInfo { kept = false; /** The id of the scope this tensor is being tracked in. */ scopeId: number; + /** The keras mask that some keras layers attach to the tensor */ + kerasMask?: Tensor; /** * Number of elements to skip in each dimension when indexing. See @@ -442,6 +444,9 @@ export class Tensor implements TensorInfo { if (this.isDisposed) { return; } + if (this.kerasMask) { + this.kerasMask.dispose(); + } trackerFn().disposeTensor(this); this.isDisposedInternal = true; } diff --git a/tfjs-layers/src/base_callbacks.ts b/tfjs-layers/src/base_callbacks.ts index 50eea0c8a56..e79cbf11eb2 100644 --- a/tfjs-layers/src/base_callbacks.ts +++ b/tfjs-layers/src/base_callbacks.ts @@ -488,7 +488,8 @@ export function standardizeCallbacks( } // Convert custom callback configs to custom callback objects. const callbackConfigs = - generic_utils.toList(callbacks) as CustomCallbackArgs[]; + generic_utils.toList( + callbacks) as CustomCallbackArgs[]; return callbackConfigs.map( callbackConfig => new CustomCallback(callbackConfig, yieldEvery)); } diff --git a/tfjs-layers/src/engine/topology.ts b/tfjs-layers/src/engine/topology.ts index 37b64663750..ffa581560ba 100644 --- a/tfjs-layers/src/engine/topology.ts +++ b/tfjs-layers/src/engine/topology.ts @@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable { */ protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor| SymbolicTensor[]): void { - inputs = generic_utils.toList(inputs); + const inputsList = generic_utils.toList(inputs); if (this.inputSpec == null || this.inputSpec.length === 0) { return; } const inputSpec = generic_utils.toList(this.inputSpec); - if (inputs.length !== inputSpec.length) { + if (inputsList.length !== inputSpec.length) { throw new ValueError( `Layer ${this.name} expects ${inputSpec.length} inputs, ` + - `but it received ${inputs.length} input tensors. ` + + `but it received ${inputsList.length} input tensors. ` + `Input received: ${inputs}`); } - for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) { - const x = inputs[inputIndex]; + for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) { + const x = inputsList[inputIndex]; const spec: InputSpec = inputSpec[inputIndex]; if (spec == null) { continue; @@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable { // Ensure inputs are all the same type. const inputsList = generic_utils.toList(inputs); - let allAreSymbolic = true; - for (const input of inputsList) { - if (!(input instanceof SymbolicTensor)) { - allAreSymbolic = false; - break; - } - } - let noneAreSymbolic = true; - for (const input of inputsList) { - if (input instanceof SymbolicTensor) { - noneAreSymbolic = false; - break; - } - } + const allAreSymbolic = checkAllSymbolic(inputs); + const noneAreSymbolic = checkNoneSymbolic(inputs); if (allAreSymbolic === noneAreSymbolic) { throw new ValueError( @@ -1017,8 +1005,13 @@ export abstract class Layer extends serialization.Serializable { // Actually call the layer, collecting output(s), mask(s), and shape(s). if (noneAreSymbolic) { - let output = this.call(inputs as Tensor | Tensor[], kwargs); - // TODO(michaelterry): Compute the outputMask + let output = this.call(inputs, kwargs); + + // Apply masks to the output tensors if the layer supports it. 
+      if (this.supportsMasking) {
+        // TODO(mattsoulanille): pass the input tensors' masks to computeMask
+        this.setMaskMetadata(inputs, output);
+      }
 
       // If the layer returns tensors from its inputs, unmodified,
       // we copy them to avoid loss of tensor metadata.
@@ -1073,8 +1066,7 @@
            If the input tensor(s) had no previous history,
            this does nothing.
         */
-      this.addInboundNode(
-          inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
+      this.addInboundNode(inputs, output, null, null,
           inputShape, outputShape, kwargs);
       this._refCount++;
@@ -1395,6 +1387,32 @@
     return mask;
   }
 
+  private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
+      previousMask?: Tensor|Tensor[]): void {
+    if (!this.supportsMasking) {
+      return;
+    }
+
+    const outputMasks = this.computeMask(inputs, previousMask);
+    if (outputs instanceof Array && outputMasks instanceof Array) {
+      if (outputs.length !== outputMasks.length) {
+        throw new Error(`${this.name} outputs ${outputs.length} tensors ` +
+            `but ${outputMasks.length} masks for those tensors`);
+      }
+      for (let i = 0; i < outputs.length; i++) {
+        outputs[i].kerasMask = outputMasks[i];
+      }
+    } else if (outputMasks instanceof Array) {
+      throw new Error(`${this.name} outputs a single tensor ` +
+          `but ${outputMasks.length} masks`);
+    } else if (outputs instanceof Array) {
+      throw new Error(`${this.name} outputs ${outputs.length} tensors ` +
+          `but only one mask`);
+    } else {
+      outputs.kerasMask = outputMasks;
+    }
+  }
+
   /**
    * Internal method to create an inbound node for the layer.
    *
@@ -1642,3 +1660,29 @@ export function getSourceInputs(
     }
   }
 }
+
+type MaybeSymbolic = SymbolicTensor | Tensor;
+
+function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
+    ): tensors is SymbolicTensor | SymbolicTensor[] {
+  let allAreSymbolic = true;
+  for (const tensor of generic_utils.toList(tensors)) {
+    if (!(tensor instanceof SymbolicTensor)) {
+      allAreSymbolic = false;
+      break;
+    }
+  }
+  return allAreSymbolic;
+}
+
+function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
+    ): tensors is Tensor | Tensor[] {
+  let noneAreSymbolic = true;
+  for (const tensor of generic_utils.toList(tensors)) {
+    if (tensor instanceof SymbolicTensor) {
+      noneAreSymbolic = false;
+      break;
+    }
+  }
+  return noneAreSymbolic;
+}
diff --git a/tfjs-layers/src/layers/nlp/multihead_attention.ts b/tfjs-layers/src/layers/nlp/multihead_attention.ts
index b89d472f27d..134ff577ff8 100644
--- a/tfjs-layers/src/layers/nlp/multihead_attention.ts
+++ b/tfjs-layers/src/layers/nlp/multihead_attention.ts
@@ -20,7 +20,7 @@
  */
 
 /* Original source: keras/layers/attention/multi_head_attention.py */
-import { Tensor, einsum, linalg, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
+import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
 // tslint:disable-next-line: no-imports-from-dist
 import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base';
 
@@ -813,12 +813,20 @@
     return tidy(() => {
       let autoMask: Tensor;
 
+      const queryMask = query.kerasMask;
+      const valueMask = value.kerasMask;
+      if (queryMask != null) {
+        autoMask = queryMask.expandDims(2); // Shape is [B, T, 1]
+      }
+      if (valueMask != null) {
+        const mask = valueMask.expandDims(1); // Shape is [B, 1, S]
+        autoMask = autoMask ?
logicalAnd(autoMask, mask) : mask; + } if (useCausalMask) { // the shape of the causal mask is [1, T, S] const mask = this.computeCausalMask(query, value); - autoMask = mask; + autoMask = autoMask ? logicalAnd(autoMask, mask) : mask; } - if (autoMask != null) { // Merge attentionMask & automatic mask, to shape [B, T, S] attentionMask = attentionMask ? diff --git a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts index a389d1ab52c..1cc5d77f0e5 100644 --- a/tfjs-layers/src/layers/nlp/multihead_attention_test.ts +++ b/tfjs-layers/src/layers/nlp/multihead_attention_test.ts @@ -25,10 +25,10 @@ import { TruncatedNormal } from '../../initializers'; import { input } from '../../exports'; import { Shape } from '../../keras_format/common'; import { MultiHeadAttention } from './multihead_attention'; -import { describeMathCPU, expectTensorsClose, expectTensorsNotClose } from '../../utils/test_utils'; +import { describeMathCPU, describeMathCPUAndGPU, expectTensorsClose, expectTensorsNotClose } from '../../utils/test_utils'; import { Embedding } from '../embeddings'; -describe('MultiHeadAttention', () => { +describeMathCPUAndGPU('MultiHeadAttention', () => { describe('Non Masked Attention', () => { interface NonMaskedAttentionArgs { @@ -188,101 +188,6 @@ describe('MultiHeadAttention', () => { expectTensorsNotClose(queryKernel, outputKernel, 1e-6); }); - describeMathCPU('High Dimensional Attention', () => { - interface HighDimAttentionArgs { - testcaseName: string; - qDims: Shape; - vDims: Shape; - maskDims: Shape; - attentionAxes: number[]; - } - /** - * Test with high dimensional inputs. - */ - function testHighDimAttention({ - testcaseName, qDims, vDims, maskDims, attentionAxes, - }: HighDimAttentionArgs) { - it(testcaseName, () => { - const testLayer = new MultiHeadAttention({ - numHeads: 2, keyDim: 2, attentionAxes, - }); - const batchSize = 3; - const hiddenSize = 8; - // Generate data for the input (non-mask) tensors. - const queryShape = [batchSize].concat(qDims).concat(hiddenSize); - const valueShape = [batchSize].concat(vDims).concat(hiddenSize); - const maskShape = [batchSize].concat(maskDims); - const query = randomUniform(queryShape, 0, 10); - const value = randomUniform(valueShape, 0, 10); - - // Invoke the data with a random set of mask data. This should mask at - // least one element. - const maskData = randomUniformInt(maskShape, 0, 2).asType('bool'); - - // Invoke the same data, but with a null mask (where no elements are - // masked). - const nullMaskData = ones(maskShape); - - // Because one data is masked and one is not, the outputs should not be - // the same. 
- - const outputWithMask = testLayer.call( - query, {value, attentionMask: maskData}); - const outputWithNullMask = testLayer.call( - query, {value, attentionMask: nullMaskData}); - - expectTensorsNotClose(outputWithMask, outputWithNullMask); - }); - } - const params: HighDimAttentionArgs[] = [ - { - testcaseName: '4d_inputs_1freebatch_mask2', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [4, 2], - attentionAxes: [2], - }, - { - testcaseName: '4d_inputs_1freebatch_mask3', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [3, 4, 2], - attentionAxes: [2], - }, - { - testcaseName: '4d_inputs_1freebatch_mask4', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [3, 2, 4, 2], - attentionAxes: [2], - }, - { - testcaseName: '4D_inputs_2D_attention', - qDims: [3, 4], - vDims: [3, 2], - maskDims: [3, 4, 3, 2], - attentionAxes: [1, 2], - }, - { - testcaseName: '5D_inputs_2D_attention', - qDims: [5, 3, 4], - vDims: [5, 3, 2], - maskDims: [3, 4, 3, 2], - attentionAxes: [2, 3], - }, - { - testcaseName: '5D_inputs_2D_attention_fullmask', - qDims: [5, 3, 4], - vDims: [5, 3, 2], - maskDims: [5, 3, 4, 3, 2], - attentionAxes: [2, 3], - }, - ]; - for (const param of params) { - testHighDimAttention(param); - } - }); - it('dropout', () => { const testLayer = new MultiHeadAttention({ numHeads: 2, @@ -298,52 +203,60 @@ describe('MultiHeadAttention', () => { expectTensorsNotClose(trainOut, testOut); }); - describe('Casual Mask Value', () => { + describe('Causal Mask Value', () => { + let testLayer: MultiHeadAttention; + let maskedQuery: Tensor; + let maskedValue: Tensor; + let mask: Tensor; + + beforeEach(() => { + testLayer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); + const query = tensor2d([ + [1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0] + ]); + const maskedQueryLayer = new Embedding( + {inputDim: 4, outputDim: 8, maskZero: true}); + maskedQuery = maskedQueryLayer.apply(query) as Tensor; + const value = tensor2d([[5, 4, 0], [3, 0, 0], [2, 1, 1]]); + maskedValue = new Embedding( + {inputDim: 6, outputDim: 8, maskZero: true}).apply(value) as Tensor; + + mask = tensor([ + Array(3).fill([true, true, false]).concat( + Array(2).fill([false, false, false])), + Array(5).fill([true, false, false]), + [[true, true, true]].concat( + Array(4).fill([false, false, false])) + ]); + }); + /** * Test that the value and causal masks are taken into account. 
*/ - function testValueMask(testcaseName: string, useCausalMask: boolean) { - it(testcaseName, () => { - const testLayer = new MultiHeadAttention({numHeads: 2, keyDim: 2}); - const query = tensor2d([ - [1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0] - ]); - const maskedQuery = new Embedding( - {inputDim: 4, outputDim: 8, maskZero: true}).apply(query) as Tensor; - const value = tensor2d([[5, 4, 0], [3, 0, 0], [2, 1, 1]]); - const maskedValue = new Embedding( - {inputDim: 6, outputDim: 8, maskZero: true}).apply(value) as Tensor; - - const output = testLayer.call( - maskedQuery, {value: maskedValue, useCausalMask: true}); - - let mask = tensor([ - Array(3).fill([true, true, false]).concat( - Array(2).fill([false, false, false])), - Array(5).fill([true, false, false]), - [[true, true, true]].concat( - Array(4).fill([false, false, false])) - ]); - if (useCausalMask) { - mask = mask.logicalAnd(tensor([ - [[true, false, false], [true, true, false]].concat( - [[true, true, true], [true, true, true], [true, true, true]]) - ])); - } + it('causal', () => { + const output = testLayer.call( + maskedQuery, {value: maskedValue, useCausalMask: true}); - const outputWithManualMask = testLayer.call( - maskedQuery, {value: maskedValue, attentionMask: mask}); + mask = mask.logicalAnd(tensor([ + [[true, false, false], [true, true, false]].concat( + [[true, true, true], [true, true, true], [true, true, true]]) + ])); - expectTensorsClose(output, outputWithManualMask); - }); - } + const outputWithManualMask = testLayer.call( + maskedQuery, {value: maskedValue, attentionMask: mask}); - const params: Array<[string, boolean]> = [ - ['casual', true], ['not_casual', false] - ]; - for (const [testName, useMask] of params) { - testValueMask(testName, useMask); - } + expectTensorsClose(output, outputWithManualMask); + }); + + it('not_causal', () => { + const output = testLayer.call( + maskedQuery, {value: maskedValue, useCausalMask: false}); + + const outputWithManualMask = testLayer.call( + maskedQuery, {value: maskedValue, attentionMask: mask}); + + expectTensorsClose(output, outputWithManualMask); + }); }); describe('Compute Output Shape', () => { @@ -482,6 +395,101 @@ describe('MultiHeadAttention', () => { // TODO(pforderique): Test serialization. }); +describeMathCPU('High Dimensional Attention', () => { + interface HighDimAttentionArgs { + testcaseName: string; + qDims: Shape; + vDims: Shape; + maskDims: Shape; + attentionAxes: number[]; + } + /** + * Test with high dimensional inputs. + */ + function testHighDimAttention({ + testcaseName, qDims, vDims, maskDims, attentionAxes, + }: HighDimAttentionArgs) { + it(testcaseName, () => { + const testLayer = new MultiHeadAttention({ + numHeads: 2, keyDim: 2, attentionAxes, + }); + const batchSize = 3; + const hiddenSize = 8; + // Generate data for the input (non-mask) tensors. + const queryShape = [batchSize].concat(qDims).concat(hiddenSize); + const valueShape = [batchSize].concat(vDims).concat(hiddenSize); + const maskShape = [batchSize].concat(maskDims); + const query = randomUniform(queryShape, 0, 10); + const value = randomUniform(valueShape, 0, 10); + + // Invoke the data with a random set of mask data. This should mask at + // least one element. + const maskData = randomUniformInt(maskShape, 0, 2).asType('bool'); + + // Invoke the same data, but with a null mask (where no elements are + // masked). + const nullMaskData = ones(maskShape); + + // Because one data is masked and one is not, the outputs should not be + // the same. 
+
+      const outputWithMask = testLayer.call(
+          query, {value, attentionMask: maskData});
+      const outputWithNullMask = testLayer.call(
+          query, {value, attentionMask: nullMaskData});
+
+      expectTensorsNotClose(outputWithMask, outputWithNullMask);
+    });
+  }
+  const params: HighDimAttentionArgs[] = [
+    {
+      testcaseName: '4d_inputs_1freebatch_mask2',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [4, 2],
+      attentionAxes: [2],
+    },
+    {
+      testcaseName: '4d_inputs_1freebatch_mask3',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [3, 4, 2],
+      attentionAxes: [2],
+    },
+    {
+      testcaseName: '4d_inputs_1freebatch_mask4',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [3, 2, 4, 2],
+      attentionAxes: [2],
+    },
+    {
+      testcaseName: '4D_inputs_2D_attention',
+      qDims: [3, 4],
+      vDims: [3, 2],
+      maskDims: [3, 4, 3, 2],
+      attentionAxes: [1, 2],
+    },
+    {
+      testcaseName: '5D_inputs_2D_attention',
+      qDims: [5, 3, 4],
+      vDims: [5, 3, 2],
+      maskDims: [3, 4, 3, 2],
+      attentionAxes: [2, 3],
+    },
+    {
+      testcaseName: '5D_inputs_2D_attention_fullmask',
+      qDims: [5, 3, 4],
+      vDims: [5, 3, 2],
+      maskDims: [5, 3, 4, 3, 2],
+      attentionAxes: [2, 3],
+    },
+  ];
+  for (const param of params) {
+    testHighDimAttention(param);
+  }
+});
+
 class SubclassAttention extends MultiHeadAttention {
   protected override buildAttention(qkvRank: number) {}
 
diff --git a/tfjs-layers/src/utils/generic_utils.ts b/tfjs-layers/src/utils/generic_utils.ts
index 2d06342c8e1..c128097f5ed 100644
--- a/tfjs-layers/src/utils/generic_utils.ts
+++ b/tfjs-layers/src/utils/generic_utils.ts
@@ -76,7 +76,7 @@ export function singletonOrArray<T>(xs: T[]): T|T[] {
  * @param x target object to be normalized.
  */
 // tslint:disable-next-line:no-any
-export function toList(x: any): any[] {
+export function toList<T>(x: T|T[]): T[] {
   if (Array.isArray(x)) {
     return x;
   }
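
Reviewer note on the intended usage: with this patch, a mask-producing layer such as `Embedding` with `maskZero: true` attaches its mask to the output tensor as `kerasMask`, and `MultiHeadAttention`'s `computeAttentionMask` picks those masks up automatically. Below is a minimal sketch of that flow; it is not part of the patch, it mirrors the test setup above, and the relative import paths are illustrative (they assume a file alongside `multihead_attention_test.ts`):

```ts
import { Tensor, tensor2d } from '@tensorflow/tfjs-core';
import { Embedding } from '../embeddings';
import { MultiHeadAttention } from './multihead_attention';

// Embedding with maskZero attaches a boolean kerasMask of shape [B, T]
// that marks the non-zero (non-padding) timesteps.
const embed = new Embedding({inputDim: 10, outputDim: 8, maskZero: true});
const ids = tensor2d([[1, 2, 3, 0, 0]]); // trailing zeros are padding
const seq = embed.apply(ids) as Tensor;  // seq.kerasMask is now set

// computeAttentionMask reads query.kerasMask ([B, T] -> [B, T, 1]) and
// value.kerasMask ([B, S] -> [B, 1, S]), ANDs them together, and ANDs in
// the causal mask when requested, so padded positions are masked without
// an explicit attentionMask argument.
const mha = new MultiHeadAttention({numHeads: 2, keyDim: 2});
const out = mha.call(seq, {value: seq, useCausalMask: true}) as Tensor;
```

Lifetime note: per the `tfjs-core` change above, `Tensor.dispose` now disposes an attached `kerasMask` together with its tensor, so the mask's lifetime follows the tensor it rides on.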