[DML EP] Enable more MHA masks #17882
The implementation LGTM. Thank you for adding this. The contrib op schema for MultiHeadAttention says that it only accepts tensors of rank <= 2 for the mask, as mentioned here:
At the same time, I also see that there is no validation of this tensor's rank in MultiHeadAttentionTypeAndShapeInference, yet the CPU EP does throw an exception if the mask tensor's rank is > 2. So I think we should also update the contrib ops doc and maybe create a bug for the CPU EP?
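To make the rank question concrete, here is a minimal Python sketch of the kind of check being discussed. It is not ONNX Runtime code: `check_mask_rank` and `MAX_DOCUMENTED_MASK_RANK` are hypothetical names invented for illustration, and the limit of 2 is taken only from the schema wording quoted in this thread.

```python
from typing import Sequence

# Assumption: the rank limit quoted in this thread (mask rank <= 2).
MAX_DOCUMENTED_MASK_RANK = 2


def check_mask_rank(mask_shape: Sequence[int],
                    max_rank: int = MAX_DOCUMENTED_MASK_RANK) -> int:
    """Return the mask rank, or raise ValueError if it exceeds max_rank.

    Hypothetical helper: mimics an EP that rejects high-rank masks,
    in contrast to shape inference that performs no rank validation.
    """
    rank = len(mask_shape)
    if rank > max_rank:
        raise ValueError(
            f"mask rank {rank} exceeds the documented maximum of {max_rank}"
        )
    return rank
```

A rank-2 shape such as `(2, 8)` passes the check, while a rank-4 shape such as `(2, 1, 8, 8)` is rejected unless `max_rank` is raised, which is roughly the gap between the documented schema and the masks this PR enables.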
The contrib op definitions are being updated in another branch in parallel (including tests), but it will all line up for 1.16.2. Note: I don't own the other branch (and it's a giant feature branch with hundreds of modified files), which is why I made this separate PR.
Understood. Thanks again!
Those masks are used for MHA in LLaMA.