[NVIDIA] Support BNTH input formats for the fused attention #20380
Conversation
Force-pushed from 6915fb6 to 25d0f6b
Just rebased the change to resolve some conflicts. Also tried to minimize the changes.
@kaixih Thanks for the PR, LGTM.
@Cjkkkk Do you know who reviewed similar changes in your previous attention PRs? Maybe we can ping them?
Peter just did the work :)
Force-pushed from 7baaeec to 18e3383
BTNH = 0
BNTH = 1

def _normalize_layout(layout_str):
How about just layout? You can use a type annotation to document the type:
def _normalize_layout(layout: str) -> AttentionLayout:
...
Yes, this is better. Fixed.
and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):

def check_qkv_layout(query, key, value, layout):
  def assert_eq(a, b, c, msg):
    assert a == b == c, msg + f' must be same: {a}, {b}, {c}'
I can see that the old version also used assert for argument validation, but since you are changing the validation logic slightly, I would recommend using raise instead. For example:
if q_rank != 4:
  raise ValueError(f"Q must have a rank of 4, got {q_rank}")
Done.
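A hedged sketch of how the reworked validation could look once raise replaces assert; the helper name check_eq and the exact messages are illustrative assumptions, not the merged code:

```python
import numpy as np

def check_qkv_layout(query, key, value, layout):
  # Raise ValueError instead of assert so validation still runs under python -O.
  def check_eq(a, b, c, msg):
    if not (a == b == c):
      raise ValueError(msg + f" must be same: {a}, {b}, {c}")

  q_rank, k_rank, v_rank = len(query.shape), len(key.shape), len(value.shape)
  if q_rank != 4:
    raise ValueError(f"Q must have a rank of 4, got {q_rank}")
  check_eq(q_rank, k_rank, v_rank, "QKV rank")
```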
Initially, it determines the attention weights by processing Q and K,
subsequently combining the outcomes using K. Throughout this function, we
utilize the following uppercase letters to represent specific parameters of
JTensor:
What's a JTensor? Perhaps you meant "array"?
Done.
Can you squash the commits, please?
Force-pushed from 1a6866f to 0489eee
Done. Also removed the trailing spaces pointed out by the failed lint tests.
Now, it seems all tests pass.
@superbobry Hi, any updates on this?
For enhanced performance and API flexibility, we've extended the functionality of dot_product_attention to accommodate QKV inputs in the BNTH format.

cc @Cjkkkk
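To illustrate the two layouts named in the title: BNTH is just BTNH with the sequence (T) and head (N) axes swapped. A small numpy sketch of that relationship follows; none of the names below are the PR's API, and the dimension sizes are arbitrary:

```python
import numpy as np

batch, seq_len, num_heads, head_dim = 2, 16, 4, 8

# Q in BTNH layout: (batch, seq_len, num_heads, head_dim)
q_btnh = np.arange(batch * seq_len * num_heads * head_dim, dtype=np.float32)
q_btnh = q_btnh.reshape(batch, seq_len, num_heads, head_dim)

# The same data in BNTH layout: swap the T (axis 1) and N (axis 2) axes.
q_bnth = np.transpose(q_btnh, (0, 2, 1, 3))

assert q_bnth.shape == (batch, num_heads, seq_len, head_dim)
# Head h at timestep t refers to the same slice in either layout.
assert np.array_equal(q_bnth[:, 1, 3, :], q_btnh[:, 3, 1, :])
```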