Make MultiHeadAttention use masks from query and value tensors #7951
tfjs-layers/src/engine/topology.ts
Outdated
    for (let i = 0; i < output.length; i++) {
      output[i].kerasMask = outputMask[i];
    }
  } else if (outputMask instanceof Array) {
What if the array contains only one mask?
I think that should be an error. If there's only one mask for all the tensors, it should be returned as a tensor instead of a [tensor].
Actually, Keras does not seem to broadcast masks at all. Each output tensor needs its own mask:
https://github.com/keras-team/keras/blob/master/keras/engine/base_layer.py#L2893-L2898
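To make the "no broadcasting" rule concrete, here is a minimal sketch of how masks could be attached one-to-one to outputs. `Mask`, `MaskedTensor`, and `attachMasks` are hypothetical stand-ins for illustration, not the actual tfjs-layers types or the code in this PR; only the `kerasMask` property name comes from the diff above.

```typescript
// Hypothetical stand-ins for illustration only (not the tfjs-layers types).
interface Mask {
  values: boolean[];
}
interface MaskedTensor {
  name: string;
  kerasMask?: Mask;
}

/**
 * Attach masks to outputs without broadcasting: a list of outputs needs a
 * list of masks of the same length, and a single output needs a single mask.
 */
function attachMasks(
    output: MaskedTensor|MaskedTensor[], outputMask: Mask|Mask[]|null): void {
  if (outputMask == null) {
    return;  // Nothing to attach.
  }
  if (Array.isArray(output)) {
    if (!Array.isArray(outputMask)) {
      throw new Error(
          `Expected ${output.length} masks but received a single mask.`);
    }
    if (outputMask.length !== output.length) {
      throw new Error(
          `Expected ${output.length} masks but received ` +
          `${outputMask.length}.`);
    }
    for (let i = 0; i < output.length; i++) {
      output[i].kerasMask = outputMask[i];
    }
  } else {
    if (Array.isArray(outputMask)) {
      // Assumption from the discussion above: [mask] for one tensor is an
      // error rather than being unwrapped.
      throw new Error('A single output needs a single mask, not an array.');
    }
    output.kerasMask = outputMask;
  }
}
```

Under these assumptions, passing one mask for two outputs, or `[mask]` for one output, throws instead of silently reusing the mask.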
  if (useCausalMask) {
    // the shape of the causal mask is [1, T, S]
    const mask = this.computeCausalMask(query, value);
-   autoMask = mask;
+   autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;
How is this associated with the Topology computeMask logic?
Earlier layers (Embedding) have computeMask called to compute the mask for their output tensors. This layer uses those masks.
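As a rough illustration of how those incoming masks could be combined, the sketch below merges a query padding mask, a value padding mask, and a causal mask with a logical AND, following the same `autoMask ? logicalAnd(autoMask, mask) : mask` pattern as the diff. It uses plain boolean arrays and drops the batch dimension for brevity; `computeAutoMask`, `buildMask`, and `and2D` are hypothetical helpers, not the tfjs-layers implementation.

```typescript
// [T, S] attention mask; the batch dimension is omitted in this sketch.
type Mask2D = boolean[][];

// Build a [T, S] mask from a predicate on (query index t, value index s).
function buildMask(
    T: number, S: number, fn: (t: number, s: number) => boolean): Mask2D {
  return Array.from(
      {length: T},
      (_row, t) => Array.from({length: S}, (_col, s) => fn(t, s)));
}

// Element-wise logical AND of two masks with the same shape.
function and2D(a: Mask2D, b: Mask2D): Mask2D {
  return a.map((row, t) => row.map((v, s) => v && b[t][s]));
}

/**
 * Combine the masks that are available. Returns null when there is no query
 * mask, no value mask, and no causal mask.
 */
function computeAutoMask(
    T: number, S: number, queryMask: boolean[]|null,
    valueMask: boolean[]|null, useCausalMask: boolean): Mask2D|null {
  let autoMask: Mask2D|null = null;
  if (queryMask) {
    const q = queryMask;
    // Broadcast the [T] query mask along the S axis.
    autoMask = buildMask(T, S, (t) => q[t]);
  }
  if (valueMask) {
    const v = valueMask;
    // Broadcast the [S] value mask along the T axis.
    const mask = buildMask(T, S, (_t, s) => v[s]);
    autoMask = autoMask ? and2D(autoMask, mask) : mask;
  }
  if (useCausalMask) {
    // Lower-triangular mask: position t may attend to positions s <= t.
    const mask = buildMask(T, S, (t, s) => s <= t);
    autoMask = autoMask ? and2D(autoMask, mask) : mask;
  }
  return autoMask;
}
```

In this sketch a position is attended to only if the query position is valid, the value position is valid, and (when requested) the value position does not lie in the future.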
111048c to 6822278
type MaybeSymbolic = SymbolicTensor | Tensor;

function checkAllSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
                         ): tensors is SymbolicTensor | SymbolicTensor[] {
  let allAreSymbolic = true;
  for (const tensor of generic_utils.toList(tensors)) {
    if (!(tensor instanceof SymbolicTensor)) {
      allAreSymbolic = false;
      break;
    }
  }
  return allAreSymbolic;
}

function checkNoneSymbolic(tensors: MaybeSymbolic | MaybeSymbolic[]
                          ): tensors is Tensor | Tensor[] {
  let noneAreSymbolic = true;
  for (const tensor of generic_utils.toList(tensors)) {
    if (tensor instanceof SymbolicTensor) {
      noneAreSymbolic = false;
      break;
    }
  }
  return noneAreSymbolic;
}
I moved these here so I could write them as type guards (tensors is Tensor...).
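For readers unfamiliar with type guards, the sketch below shows what the `tensors is ...` return type buys at the call site: inside each branch the compiler narrows the argument, so no casts are needed. The `SymbolicTensor` and `Tensor` classes here are simplified placeholders, and `toList` and `describeInputs` are hypothetical helpers written for this example, not the tfjs-layers versions.

```typescript
// Simplified placeholder classes for illustration only.
class SymbolicTensor {
  constructor(readonly name: string) {}
}
class Tensor {
  constructor(readonly data: number[]) {}
}
type MaybeSymbolic = SymbolicTensor|Tensor;

// Hypothetical helper: wrap a single item in a list.
function toList<T>(x: T|T[]): T[] {
  return Array.isArray(x) ? x : [x];
}

function checkAllSymbolic(tensors: MaybeSymbolic|MaybeSymbolic[]):
    tensors is SymbolicTensor|SymbolicTensor[] {
  return toList<MaybeSymbolic>(tensors).every(
      (t) => t instanceof SymbolicTensor);
}

function checkNoneSymbolic(tensors: MaybeSymbolic|MaybeSymbolic[]):
    tensors is Tensor|Tensor[] {
  return toList<MaybeSymbolic>(tensors).every(
      (t) => !(t instanceof SymbolicTensor));
}

function describeInputs(inputs: MaybeSymbolic|MaybeSymbolic[]): string {
  if (checkAllSymbolic(inputs)) {
    // Narrowed to SymbolicTensor | SymbolicTensor[]: `.name` type-checks.
    return 'symbolic: ' +
        toList<SymbolicTensor>(inputs).map((t) => t.name).join(', ');
  } else if (checkNoneSymbolic(inputs)) {
    // Narrowed to Tensor | Tensor[]: `.data` type-checks.
    return 'concrete: ' +
        toList<Tensor>(inputs).map((t) => t.data.length).join(', ');
  }
  throw new Error('Inputs must be all symbolic or all concrete.');
}
```

With plain `boolean` return types the same checks would still work at runtime, but each branch would need an explicit cast before accessing `name` or `data`.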
@@ -188,101 +188,6 @@ describe('MultiHeadAttention', () => {
     expectTensorsNotClose(queryKernel, outputKernel, 1e-6);
   });

-  describeMathCPU('High Dimensional Attention', () => {
Moved
  /**
   * Test that the value and causal masks are taken into account.
   */
  function testValueMask(testcaseName: string, useCausalMask: boolean) {
Refactored to explicitly declare both tests instead of using a loop to declare them.
@@ -482,6 +395,101 @@ describe('MultiHeadAttention', () => {
     // TODO(pforderique): Test serialization.
   });

+  describeMathCPU('High Dimensional Attention', () => {
Moved some tests from above to here.
Add optional Keras masks to tfjs tensors. Enable them for tfjs-layers on layers that emit them. Use the masks of query and value input tensors in MultiHeadAttention to compute the correct mask automatically.