Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Support BNTH input formats for the fused attention #20380

Merged
merged 1 commit into from
Apr 7, 2024

Conversation

kaixih
Copy link
Contributor

@kaixih kaixih commented Mar 21, 2024

For enhanced performance and API flexibility, we've extended the functionality of dot_production_attention to accommodate QKV inputs in the BNTH format.

cc. @Cjkkkk

@kaixih kaixih force-pushed the cudnn_attention_dev branch from 6915fb6 to 25d0f6b Compare March 26, 2024 20:55
@kaixih
Copy link
Contributor Author

kaixih commented Mar 26, 2024

Just rebased the change to resolve some conflicts. Also tried to minimize the changes.

@Cjkkkk
Copy link
Contributor

Cjkkkk commented Mar 27, 2024

@kaixih Thanks for the PR, LGTM.

@kaixih
Copy link
Contributor Author

kaixih commented Mar 27, 2024

@Cjkkkk Do you know who was the reviewer for such changes to your previous attn changes? Maybe we can ping him/her?

@hawkinsp hawkinsp requested a review from superbobry March 27, 2024 21:05
@Cjkkkk
Copy link
Contributor

Cjkkkk commented Mar 27, 2024

@Cjkkkk Do you know who was the reviewer for such changes to your previous attn changes? Maybe we can ping him/her?

Peter just did the work :)

@kaixih kaixih force-pushed the cudnn_attention_dev branch 2 times, most recently from 7baaeec to 18e3383 Compare April 1, 2024 19:49
BTNH = 0
BNTH = 1

def _normalize_layout(layout_str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just layout? You can use a type annotation to document the type:

def _normalize_layout(layout: str) -> AttentionLayout:
  ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is better. Fixed.

and (not is_training or q_seq_len % 64 == 0 and kv_seq_len % 64 == 0):
def check_qkv_layout(query, key, value, layout):
def assert_eq(a, b, c, msg):
assert a == b == c, msg + f' must be same: {a}, {b}, {c}'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see that the old version also used assert for argument validation, but since you are changing the validation logic slightly, I would recommend using raise instead. For example:

if q_rank != 4:
  raise ValueError(f"Q must have a ran of 4, got {q_rank}")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Initially, it determines the attention weights by processing Q and K,
subsequently combining the outcomes using K. Throughout this function, we
utilize the following uppercase letters to represent specific parameters of
JTensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's a JTensor? Perhaps you meant "array"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 2, 2024
@superbobry
Copy link
Collaborator

Can you squash the commits, please?

@kaixih kaixih force-pushed the cudnn_attention_dev branch from 1a6866f to 0489eee Compare April 3, 2024 20:48
@kaixih
Copy link
Contributor Author

kaixih commented Apr 3, 2024

Can you squash the commits, please?

Done. Also removed the trailing spaces pointed out by the failed lint tests.

@kaixih
Copy link
Contributor Author

kaixih commented Apr 3, 2024

Now, it seems all tests pass.
@superbobry PTAL.

@Cjkkkk
Copy link
Contributor

Cjkkkk commented Apr 5, 2024

@superbobry Hi, any updates on this?

@copybara-service copybara-service bot merged commit 9a931af into jax-ml:main Apr 7, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants