
Make MultiHeadAttention use masks from query and value tensors #7951

Merged: 6 commits into tensorflow:master on Sep 12, 2023

Conversation

mattsoulanille (Member):

Add optional Keras masks to tfjs tensors. Enable them for tfjs-layers on layers that emit them. Use the masks of query and value input tensors in MultiHeadAttention to compute the correct mask automatically.
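As a rough sketch of what this enables (the Embedding configuration and values below are illustrative, not code from this PR):

    import * as tf from '@tensorflow/tfjs';

    // A mask-emitting layer: maskZero treats id 0 as padding.
    const embedding = tf.layers.embedding(
        {inputDim: 1000, outputDim: 8, maskZero: true});
    const ids = tf.tensor2d([[5, 12, 0, 0]], [1, 4], 'int32');
    const embedded = embedding.apply(ids) as tf.Tensor;

    // With this change, the output tensor carries its Keras mask, so
    // downstream layers can consume it without an explicit mask argument.
    embedded.kerasMask?.print();  // expected shape [1, 4]: true for ids 5, 12

In particular, MultiHeadAttention can read the masks attached to its query and value inputs instead of requiring the attention mask to be passed explicitly.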


      for (let i = 0; i < output.length; i++) {
        output[i].kerasMask = outputMask[i];
      }
    } else if (outputMask instanceof Array) {

Collaborator:

What if the array only contains one mask?

mattsoulanille (Member, Author):

I think that should be an error. If there's only one mask for all the tensors, it should be returned as a tensor instead of a [tensor].

mattsoulanille (Member, Author):

Actually, Keras does not seem to broadcast masks at all. Each output tensor needs its own mask:
https://github.com/keras-team/keras/blob/master/keras/engine/base_layer.py#L2893-L2898
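
A minimal sketch of the pairing invariant that implies (the check and error message are illustrative, not the PR's exact code):

    import {Tensor} from '@tensorflow/tfjs';

    // One mask per output tensor; a single mask is never broadcast.
    function checkMaskPairing(output: Tensor[], outputMask: Tensor[]): void {
      if (outputMask.length !== output.length) {
        throw new Error(
            `Expected one mask per output tensor, but got ` +
            `${outputMask.length} mask(s) for ${output.length} output(s).`);
      }
    }

This mirrors the Keras check referenced above.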

    if (useCausalMask) {
      // the shape of the causal mask is [1, T, S]
      const mask = this.computeCausalMask(query, value);
      autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
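
For reference, a minimal sketch of how such a [1, T, S] lower-triangular mask can be materialized (the helper name is illustrative; the layer's own computeCausalMask is the actual implementation):

    import * as tf from '@tensorflow/tfjs';

    // Position t in the query may only attend to positions s <= t in the value.
    function causalMask(tSize: number, sSize: number): tf.Tensor {
      // bandPart(x, -1, 0) keeps the lower triangle, diagonal included.
      return tf.linalg.bandPart(tf.ones([1, tSize, sSize]), -1, 0);
    }

    causalMask(3, 3).print();
    // [[[1, 0, 0],
    //   [1, 1, 0],
    //   [1, 1, 1]]]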

Collaborator:

How is this associated with the Topology computeMask logic?

mattsoulanille (Member, Author):

Earlier layers (e.g. Embedding) have computeMask called to compute the masks for their output tensors. This layer then uses those masks, as sketched below.
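
For concreteness, a sketch of how per-input masks can be combined into the attention mask (the helper and shapes follow the Keras convention; this is not the PR's exact code):

    import * as tf from '@tensorflow/tfjs';

    // queryMask: [B, T] (bool), valueMask: [B, S] (bool)  ->  [B, T, S].
    function combineMasks(queryMask: tf.Tensor2D,
                          valueMask: tf.Tensor2D): tf.Tensor {
      const q = queryMask.expandDims(2);  // [B, T, 1]
      const v = valueMask.expandDims(1);  // [B, 1, S]
      return tf.logicalAnd(q, v);         // broadcasts to [B, T, S]
    }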

Comment on lines +1664 to +1688
    type MaybeSymbolic = SymbolicTensor | Tensor;

    function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
    ): tensors is SymbolicTensor | SymbolicTensor[] {
      let allAreSymbolic = true;
      for (const tensor of generic_utils.toList(tensors)) {
        if (!(tensor instanceof SymbolicTensor)) {
          allAreSymbolic = false;
          break;
        }
      }
      return allAreSymbolic;
    }

    function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
    ): tensors is Tensor | Tensor[] {
      let noneAreSymbolic = true;
      for (const tensor of generic_utils.toList(tensors)) {
        if (tensor instanceof SymbolicTensor) {
          noneAreSymbolic = false;
          break;
        }
      }
      return noneAreSymbolic;
    }

mattsoulanille (Member, Author):

I moved these here so I could write them as type guards (tensors is Tensor...).
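
A sketch of the call-site benefit (the variable name is illustrative): inside each branch, TypeScript narrows the union without casts.

    declare const inputs: MaybeSymbolic | MaybeSymbolic[];

    if (checkAllSymbolic(inputs)) {
      // inputs: SymbolicTensor | SymbolicTensor[] -- graph-building path.
    } else if (checkNoneSymbolic(inputs)) {
      // inputs: Tensor | Tensor[] -- eager path.
    } else {
      throw new Error('Inputs mix symbolic and concrete tensors.');
    }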

@@ -188,101 +188,6 @@ describe('MultiHeadAttention', () => {
expectTensorsNotClose(queryKernel, outputKernel, 1e-6);
});

describeMathCPU('High Dimensional Attention', () => {

mattsoulanille (Member, Author):

Moved (to the High Dimensional Attention block below).

    /**
     * Test that the value and causal masks are taken into account.
     */
    function testValueMask(testcaseName: string, useCausalMask: boolean) {

mattsoulanille (Member, Author):

Refactored to explicitly declare both tests instead of using a loop to declare them.

@@ -482,6 +395,101 @@ describe('MultiHeadAttention', () => {
// TODO(pforderique): Test serialization.
});

describeMathCPU('High Dimensional Attention', () => {

mattsoulanille (Member, Author):

Moved some tests from above to here.

mattsoulanille requested review from fengwuyao and removed the review request for Linchenn on Sep 12, 2023.
pyu10055 merged commit 8879e72 into tensorflow:master on Sep 12, 2023. 2 checks passed.