Skip to content

Commit

Permalink
Make MultiHeadAttention use masks from query and value tensors (#7951)
Browse files Browse the repository at this point in the history
BUG
* Use describeWithFlags

* Separate tests out of for loop

* Make MHA use masks from query and value tensors

* Fix lint

* Refactor mask computation into a separate function
  • Loading branch information
mattsoulanille authored Sep 12, 2023
1 parent f44e224 commit 8879e72
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 165 deletions.
5 changes: 5 additions & 0 deletions tfjs-core/src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ export class Tensor<R extends Rank = Rank> implements TensorInfo {
kept = false;
/** The id of the scope this tensor is being tracked in. */
scopeId: number;
/** The keras mask that some keras layers attach to the tensor */
kerasMask?: Tensor;

/**
* Number of elements to skip in each dimension when indexing. See
Expand Down Expand Up @@ -442,6 +444,9 @@ export class Tensor<R extends Rank = Rank> implements TensorInfo {
if (this.isDisposed) {
return;
}
if (this.kerasMask) {
this.kerasMask.dispose();
}
trackerFn().disposeTensor(this);
this.isDisposedInternal = true;
}
Expand Down
3 changes: 2 additions & 1 deletion tfjs-layers/src/base_callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ export function standardizeCallbacks(
}
// Convert custom callback configs to custom callback objects.
const callbackConfigs =
generic_utils.toList(callbacks) as CustomCallbackArgs[];
generic_utils.toList<BaseCallback | CustomCallbackArgs>(
callbacks) as CustomCallbackArgs[];
return callbackConfigs.map(
callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
}
Expand Down
90 changes: 67 additions & 23 deletions tfjs-layers/src/engine/topology.ts
Original file line number Diff line number Diff line change
Expand Up @@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable {
*/
protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor|
SymbolicTensor[]): void {
inputs = generic_utils.toList(inputs);
const inputsList = generic_utils.toList(inputs);
if (this.inputSpec == null || this.inputSpec.length === 0) {
return;
}
const inputSpec = generic_utils.toList(this.inputSpec);
if (inputs.length !== inputSpec.length) {
if (inputsList.length !== inputSpec.length) {
throw new ValueError(
`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
`but it received ${inputs.length} input tensors. ` +
`but it received ${inputsList.length} input tensors. ` +
`Input received: ${inputs}`);
}
for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
const x = inputs[inputIndex];
for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
const x = inputsList[inputIndex];
const spec: InputSpec = inputSpec[inputIndex];
if (spec == null) {
continue;
Expand Down Expand Up @@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
// Ensure inputs are all the same type.
const inputsList = generic_utils.toList(inputs);

let allAreSymbolic = true;
for (const input of inputsList) {
if (!(input instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
let noneAreSymbolic = true;
for (const input of inputsList) {
if (input instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
const allAreSymbolic = checkAllSymbolic(inputs);
const noneAreSymbolic = checkNoneSymbolic(inputs);

if (allAreSymbolic === noneAreSymbolic) {
throw new ValueError(
Expand Down Expand Up @@ -1017,8 +1005,13 @@ export abstract class Layer extends serialization.Serializable {

// Actually call the layer, collecting output(s), mask(s), and shape(s).
if (noneAreSymbolic) {
let output = this.call(inputs as Tensor | Tensor[], kwargs);
// TODO(michaelterry): Compute the outputMask
let output = this.call(inputs, kwargs);

// Apply masks to the output tensors if the layer supports it.
if (this.supportsMasking) {
// TODO(mattsoulanille): pass the input tensors' masks to computeMask
this.setMaskMetadata(inputs, output);
}

// If the layer returns tensors from its inputs, unmodified,
// we copy them to avoid loss of tensor metadata.
Expand Down Expand Up @@ -1073,8 +1066,7 @@ export abstract class Layer extends serialization.Serializable {
If the input tensor(s) had no previous history,
this does nothing.
*/
this.addInboundNode(
inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
this.addInboundNode(inputs, output, null, null,
inputShape, outputShape, kwargs);
this._refCount++;

Expand Down Expand Up @@ -1395,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
return mask;
}

/**
 * Computes this layer's mask(s) and attaches them to the output tensor(s).
 *
 * The mask(s) are produced by `computeMask(inputs, previousMask)` and stored
 * on each output tensor's `kerasMask` property so downstream mask-consuming
 * layers (e.g. attention layers) can read them.
 *
 * @param inputs Input tensor(s) the layer was called with.
 * @param outputs Output tensor(s) returned by the layer's `call`.
 * @param previousMask Optional mask(s) already attached to the inputs.
 * @throws Error if the number of computed masks does not match the number
 *   of output tensors.
 */
private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
                        previousMask?: Tensor|Tensor[]): void {
  if (!this.supportsMasking) {
    return;
  }

  const outputMasks = this.computeMask(inputs, previousMask);
  if (outputs instanceof Array && outputMasks instanceof Array) {
    if (outputs.length !== outputMasks.length) {
      throw new Error(`${this.name} outputs ${outputs.length} tensors `
          + `but ${outputMasks.length} masks for those tensors`);
    }
    for (let i = 0; i < outputs.length; i++) {
      outputs[i].kerasMask = outputMasks[i];
    }
  } else if (outputMasks instanceof Array) {
    // Bug fix: the original used `{this.name}` (missing `$`), so the
    // literal text "{this.name}" appeared in the error message instead
    // of the layer's actual name.
    throw new Error(`${this.name} outputs a single tensor `
        + `but ${outputMasks.length} masks`);
  } else if (outputs instanceof Array) {
    // Bug fix: same missing-`$` interpolation defect as above.
    throw new Error(`${this.name} outputs ${outputs.length} tensors `
        + `but only one mask`);
  } else {
    outputs.kerasMask = outputMasks;
  }
}

/**
* Internal method to create an inbound node for the layer.
*
Expand Down Expand Up @@ -1642,3 +1660,29 @@ export function getSourceInputs(
}
}
}

type MaybeSymbolic = SymbolicTensor | Tensor;

/**
 * Type guard: true iff every given tensor is a `SymbolicTensor`.
 *
 * Accepts either a single tensor or an array; a single tensor is treated
 * as a one-element list.
 */
function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
): tensors is SymbolicTensor | SymbolicTensor[] {
  return generic_utils.toList(tensors)
      .every((t) => t instanceof SymbolicTensor);
}

/**
 * Type guard: true iff none of the given tensors is a `SymbolicTensor`.
 *
 * Accepts either a single tensor or an array; a single tensor is treated
 * as a one-element list.
 */
function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
): tensors is Tensor | Tensor[] {
  for (const t of generic_utils.toList(tensors)) {
    if (t instanceof SymbolicTensor) {
      return false;
    }
  }
  return true;
}
14 changes: 11 additions & 3 deletions tfjs-layers/src/layers/nlp/multihead_attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
*/

/* Original source: keras/layers/attention/multi_head_attention.py */
import { Tensor, einsum, linalg, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base';

Expand Down Expand Up @@ -813,12 +813,20 @@ export class MultiHeadAttention extends Layer {
return tidy(() => {
let autoMask: Tensor;

const queryMask = query.kerasMask;
const valueMask = value.kerasMask;
if (queryMask != null) {
autoMask = queryMask.expandDims(2); // Shape is [B, T, 1]
}
if (valueMask != null) {
const mask = valueMask.expandDims(1); // Shape is [B, 1, S]
autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
}
if (useCausalMask) {
// the shape of the causal mask is [1, T, S]
const mask = this.computeCausalMask(query, value);
autoMask = mask;
autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
}

if (autoMask != null) {
// Merge attentionMask & automatic mask, to shape [B, T, S]
attentionMask = attentionMask ?
Expand Down
Loading

0 comments on commit 8879e72

Please sign in to comment.