
[DML EP] Enable more MHA masks #17882

Merged 1 commit into main on Oct 18, 2023
Conversation

PatriceVignola (Contributor)

Those masks are used for MHA in LLaMA.

@sumitsays (Contributor)

The implementation LGTM. Thank you for adding this.

The contrib op schema for MultiheadAttention says that it only accepts tensors of rank <= 2 for the mask, as mentioned here:

"Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)",

At the same time, I also see that there is no validation on the rank of this tensor in MultiHeadAttentionTypeAndShapeInference, but the CPU EP does throw an exception if the mask tensor rank is > 2.

So I think we should also update the contrib ops doc and maybe create a bug for the CPU EP?
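For context, a minimal illustrative sketch (not part of this PR) of how the 1-D (batch_size) mask form could be expanded into the 2-D (batch_size, kv_sequence_length) key padding mask, assuming the 1-D form holds the number of valid key tokens per batch entry:

```python
import numpy as np

def key_lengths_to_padding_mask(key_lengths: np.ndarray, kv_sequence_length: int) -> np.ndarray:
    # Hypothetical helper: expand per-batch valid key lengths into a
    # (batch_size, kv_sequence_length) mask with 1 for valid positions
    # and 0 for padded positions.
    positions = np.arange(kv_sequence_length)                     # (kv_sequence_length,)
    return (positions[None, :] < key_lengths[:, None]).astype(np.int32)

# Example: batch of 2, kv_sequence_length of 4, valid lengths 3 and 2
print(key_lengths_to_padding_mask(np.array([3, 2]), 4))
# [[1 1 1 0]
#  [1 1 0 0]]
```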

@PatriceVignola (Contributor, Author) commented on Oct 16, 2023

> The implementation LGTM. Thank you for adding this.
>
> The contrib op schema for MultiheadAttention says that it only accepts tensors of rank <= 2 for the mask, as mentioned here:
>
> "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)",
>
> At the same time, I also see that there is no validation on the rank of this tensor in MultiHeadAttentionTypeAndShapeInference, but the CPU EP does throw an exception if the mask tensor rank is > 2.
>
> So I think we should also update the contrib ops doc and maybe create a bug for the CPU EP?

The contrib op definitions are being updated in another branch in parallel (including tests), but it will all line up for 1.16.2.

Note: I don't own the other branch (and it's a giant feature branch with hundreds of modified files), which is why I made this separate PR.

@sumitsays (Contributor)


> The contrib op definitions are being updated in another branch in parallel (including tests), but it will all line up for 1.16.2.
>
> Note: I don't own the other branch (and it's a giant feature branch with hundreds of modified files), which is why I made this separate PR.

Understood. Thanks again!

PatriceVignola merged commit 6557538 into main on Oct 18, 2023.
PatriceVignola deleted the user/pavignol/allow-more-mha-masks branch on Oct 18, 2023 at 00:31.
jchen351 pushed a commit that referenced this pull request on Oct 18, 2023: "Those masks are used for MHA in LLaMA."
PatriceVignola added a commit that referenced this pull request on Oct 26, 2023: "Those masks are used for MHA in LLaMA."
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request on Mar 22, 2024: "Those masks are used for MHA in LLaMA."