Make MultiHeadAttention use masks from query and value tensors #7951

Merged (6 commits) on Sep 12, 2023

Changes from all commits
5 changes: 5 additions & 0 deletions tfjs-core/src/tensor.ts
@@ -275,6 +275,8 @@ export class Tensor<R extends Rank = Rank> implements TensorInfo {
  kept = false;
  /** The id of the scope this tensor is being tracked in. */
  scopeId: number;
+  /** The keras mask that some keras layers attach to the tensor */
+  kerasMask?: Tensor;

  /**
   * Number of elements to skip in each dimension when indexing. See
@@ -442,6 +444,9 @@ export class Tensor<R extends Rank = Rank> implements TensorInfo {
    if (this.isDisposed) {
      return;
    }
+    if (this.kerasMask) {
+      this.kerasMask.dispose();
+    }
    trackerFn().disposeTensor(this);
    this.isDisposedInternal = true;
  }
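The new field ties a mask's lifetime to the tensor it annotates. A minimal sketch of the intended behavior, assuming a build that includes this change (shapes and values here are illustrative, not from the PR):

```ts
import * as tf from '@tensorflow/tfjs';

// Illustrative only: attach a boolean mask to a tensor, then dispose it.
const features = tf.zeros([2, 4, 8]);  // [batch, timesteps, channels]
features.kerasMask = tf.tensor2d(
    [[1, 1, 0, 0], [1, 0, 0, 0]], [2, 4], 'bool');  // true = real token

features.dispose();
// With this change, dispose() also releases the attached mask.
console.log(features.kerasMask.isDisposed);  // true
```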
3 changes: 2 additions & 1 deletion tfjs-layers/src/base_callbacks.ts
@@ -488,7 +488,8 @@ export function standardizeCallbacks(
  }
  // Convert custom callback configs to custom callback objects.
  const callbackConfigs =
-      generic_utils.toList(callbacks) as CustomCallbackArgs[];
+      generic_utils.toList<BaseCallback | CustomCallbackArgs>(
+          callbacks) as CustomCallbackArgs[];
  return callbackConfigs.map(
      callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
}
90 changes: 67 additions & 23 deletions tfjs-layers/src/engine/topology.ts
@@ -751,19 +751,19 @@ export abstract class Layer extends serialization.Serializable {
   */
  protected assertInputCompatibility(inputs: Tensor|Tensor[]|SymbolicTensor|
                                     SymbolicTensor[]): void {
-    inputs = generic_utils.toList(inputs);
+    const inputsList = generic_utils.toList(inputs);
    if (this.inputSpec == null || this.inputSpec.length === 0) {
      return;
    }
    const inputSpec = generic_utils.toList(this.inputSpec);
-    if (inputs.length !== inputSpec.length) {
+    if (inputsList.length !== inputSpec.length) {
      throw new ValueError(
          `Layer ${this.name} expects ${inputSpec.length} inputs, ` +
-          `but it received ${inputs.length} input tensors. ` +
+          `but it received ${inputsList.length} input tensors. ` +
          `Input received: ${inputs}`);
    }
-    for (let inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
-      const x = inputs[inputIndex];
+    for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
+      const x = inputsList[inputIndex];
      const spec: InputSpec = inputSpec[inputIndex];
      if (spec == null) {
        continue;
@@ -954,20 +954,8 @@ export abstract class Layer extends serialization.Serializable {
    // Ensure inputs are all the same type.
    const inputsList = generic_utils.toList(inputs);

-    let allAreSymbolic = true;
-    for (const input of inputsList) {
-      if (!(input instanceof SymbolicTensor)) {
-        allAreSymbolic = false;
-        break;
-      }
-    }
-    let noneAreSymbolic = true;
-    for (const input of inputsList) {
-      if (input instanceof SymbolicTensor) {
-        noneAreSymbolic = false;
-        break;
-      }
-    }
+    const allAreSymbolic = checkAllSymbolic(inputs);
+    const noneAreSymbolic = checkNoneSymbolic(inputs);

    if (allAreSymbolic === noneAreSymbolic) {
      throw new ValueError(
@@ -1017,8 +1005,13 @@

    // Actually call the layer, collecting output(s), mask(s), and shape(s).
    if (noneAreSymbolic) {
-      let output = this.call(inputs as Tensor | Tensor[], kwargs);
-      // TODO(michaelterry): Compute the outputMask
+      let output = this.call(inputs, kwargs);
+
+      // Apply masks to the output tensors if the layer supports it.
+      if (this.supportsMasking) {
+        // TODO(mattsoulanille): pass the input tensors' masks to computeMask
+        this.setMaskMetadata(inputs, output);
+      }

      // If the layer returns tensors from its inputs, unmodified,
      // we copy them to avoid loss of tensor metadata.
@@ -1073,8 +1066,7 @@
      If the input tensor(s) had no previous history,
      this does nothing.
      */
-      this.addInboundNode(
-          inputs as SymbolicTensor | SymbolicTensor[], output, null, null,
+      this.addInboundNode(inputs, output, null, null,
          inputShape, outputShape, kwargs);
      this._refCount++;

@@ -1395,6 +1387,32 @@ export abstract class Layer extends serialization.Serializable {
    return mask;
  }

+  private setMaskMetadata(inputs: Tensor|Tensor[], outputs: Tensor|Tensor[],
+                          previousMask?: Tensor|Tensor[]): void {
+    if (!this.supportsMasking) {
+      return;
+    }
+
+    const outputMasks = this.computeMask(inputs, previousMask);
+    if (outputs instanceof Array && outputMasks instanceof Array) {
+      if (outputs.length !== outputMasks.length) {
+        throw new Error(`${this.name} outputs ${outputs.length} tensors `
+            + `but ${outputMasks.length} masks for those tensors`);
+      }
+      for (let i = 0; i < outputs.length; i++) {
+        outputs[i].kerasMask = outputMasks[i];
+      }
+    } else if (outputMasks instanceof Array) {
+      throw new Error(`${this.name} outputs a single tensor `
+          + `but ${outputMasks.length} masks`);
+    } else if (outputs instanceof Array) {
+      throw new Error(`${this.name} outputs ${outputs.length} tensors `
+          + `but only one mask`);
+    } else {
+      outputs.kerasMask = outputMasks;
+    }
+  }
+
  /**
   * Internal method to create an inbound node for the layer.
   *
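For a hypothetical multi-output layer, this bookkeeping reduces to zipping masks onto outputs positionally; any arity mismatch throws, as the branches above show. A hedged sketch with invented shapes:

```ts
import { Tensor, ones, zeros } from '@tensorflow/tfjs-core';

// Hypothetical two-output case: computeMask yields one mask per output,
// and each mask is attached positionally (the array/array branch above).
const outputs: Tensor[] = [ones([2, 4, 8]), ones([2, 4, 16])];
const outputMasks: Tensor[] = [zeros([2, 4], 'bool'), zeros([2, 4], 'bool')];

outputs.forEach((output, i) => {
  output.kerasMask = outputMasks[i];
});
// A single tensor paired with an array of masks (or the reverse) is exactly
// the mismatch the two middle error branches guard against.
```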
@@ -1642,3 +1660,29 @@ export function getSourceInputs(
    }
  }
}

+type MaybeSymbolic = SymbolicTensor | Tensor;
+
+function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
+    ): tensors is SymbolicTensor | SymbolicTensor[] {
+  let allAreSymbolic = true;
+  for (const tensor of generic_utils.toList(tensors)) {
+    if (!(tensor instanceof SymbolicTensor)) {
+      allAreSymbolic = false;
+      break;
+    }
+  }
+  return allAreSymbolic;
+}
+
+function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
+    ): tensors is Tensor | Tensor[] {
+  let noneAreSymbolic = true;
+  for (const tensor of generic_utils.toList(tensors)) {
+    if (tensor instanceof SymbolicTensor) {
+      noneAreSymbolic = false;
+      break;
+    }
+  }
+  return noneAreSymbolic;
+}
Comment on lines +1664 to +1688 (Member Author):
I moved this here so I could write them as type guards (tensors is Tensor...)
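Written as type guards, these checks let the compiler do the narrowing that the removed `as Tensor | Tensor[]` casts used to assert by hand. A minimal standalone sketch of the pattern (the classes here are stand-ins, not the real tfjs imports):

```ts
// Stand-ins for illustration; the real types come from tfjs.
class SymbolicTensor {}
class Tensor {}
type MaybeSymbolic = SymbolicTensor | Tensor;

function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]):
    tensors is Tensor | Tensor[] {
  const list = Array.isArray(tensors) ? tensors : [tensors];
  return list.every(t => !(t instanceof SymbolicTensor));
}

declare const inputs: MaybeSymbolic | MaybeSymbolic[];
if (checkNoneSymbolic(inputs)) {
  // Narrowed: `inputs` is Tensor | Tensor[] in this branch, so calls like
  // this.call(inputs, kwargs) and setMaskMetadata(inputs, output) need no casts.
}
```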

14 changes: 11 additions & 3 deletions tfjs-layers/src/layers/nlp/multihead_attention.ts
@@ -20,7 +20,7 @@
 */

/* Original source: keras/layers/attention/multi_head_attention.py */
-import { Tensor, einsum, linalg, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
+import { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy } from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import { arraysEqual } from '@tensorflow/tfjs-core/dist/util_base';

@@ -813,12 +813,20 @@ export class MultiHeadAttention extends Layer {
    return tidy(() => {
      let autoMask: Tensor;

+      const queryMask = query.kerasMask;
+      const valueMask = value.kerasMask;
+      if (queryMask != null) {
+        autoMask = queryMask.expandDims(2);  // Shape is [B, T, 1]
+      }
+      if (valueMask != null) {
+        const mask = valueMask.expandDims(1);  // Shape is [B, 1, S]
+        autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
+      }
      if (useCausalMask) {
        // the shape of the causal mask is [1, T, S]
        const mask = this.computeCausalMask(query, value);
-        autoMask = mask;
+        autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
Collaborator: How is this associated with the Topology computeMask logic?

Member Author: Earlier layers (Embedding) have computeMask called to compute the mask for their output tensors. This layer uses those masks.

      }

      if (autoMask != null) {
        // Merge attentionMask & automatic mask, to shape [B, T, S]
        attentionMask = attentionMask ?
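The shape arithmetic behind the merged block, as a standalone sketch (batch of 1, T=3 query positions, S=2 value positions; values invented):

```ts
import { logicalAnd, tensor2d } from '@tensorflow/tfjs-core';

const queryMask = tensor2d([[1, 1, 0]], [1, 3], 'bool');  // [B, T]
const valueMask = tensor2d([[1, 0]], [1, 2], 'bool');     // [B, S]

// expandDims gives [B, T, 1] and [B, 1, S]; logicalAnd broadcasts to [B, T, S].
const autoMask = logicalAnd(queryMask.expandDims(2), valueMask.expandDims(1));
autoMask.print();  // position (t, s) survives only if both tokens are real
```

With useCausalMask, the [1, T, S] causal mask is ANDed into the same accumulator, and any user-supplied attentionMask is merged in last, yielding the final [B, T, S] mask.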