Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Support custom dtype convert in jax.nn.dot_product_attention #24352

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kaixih
Copy link
Contributor

@kaixih kaixih commented Oct 16, 2024

Addressing the issue brought up in 24047.

This PR does this:

# We use this custom dot_general in the QK einsum op of attention to match
# dtypes used in the Flash Attention implementation. For bf16 inputs as an
# example, the fprop is like:
#   bf16 -> dot -> fp32
# Then the bprop is like:
# (1) Without this change:
#   fp32 -> dot -> fp32 -> cvt -> bf16.
# (2) With this change:
#   fp32 -> cvt -> bf16 -> dot -> bf16.

In addition, we adjust the atol/rtol a bit.

cc. @sbodenstein @superbobry

@sbodenstein
Copy link
Contributor

I think that this is best fixed using the new precision API, rather than custom JVP. What is the motivation for this approach, other than being able to land it a little bit faster?

@kaixih
Copy link
Contributor Author

kaixih commented Oct 21, 2024

Right, the main motivation is to get it implemented faster. I'm okay with using the new precision API. Do you know when it will be available, especially with the PJRT plugin mentioned here? If it’s coming soon, we can close this PR. If not, do you think we should keep this one open and migrate to the new method later? Also, pinging @dfm for comments.

Copy link
Collaborator

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ping! I think this is a reasonable approach to consider, but I would also prefer to fix this using DotAlgorithm like @sbodenstein. I believe that the DotAlgorithm approach would fix the the issue that I've highlighted inline out of the box. I'll check in to see what timeline we can expect for the PJRT fixes. Thanks @kaixih!

Comment on lines +876 to +881
grad_out = lax.dot_general(
lhs_dot, rhs, dimension_numbers, precision=precision,
) + lax.dot_general(
lhs, rhs_dot, dimension_numbers, precision=precision,
)
grad_out = grad_out.astype(preferred_element_type)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you talk through the argument here in a bit more detail? This doesn't seem like a great idea, because it's not clear to me why we would want to accumulate the tangents using a different dtype than the primals. I see why this ends up giving the correct dtypes on the backwards pass, but it seems bad for the numerics of fwd mode.

Perhaps there's something I'm missing here, but if not, we'd probably want to use custom_vjp instead of custom_jvp because then it's clear that the fwd pass isn't appropriate.

WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason we use jvp instead of vjp is that we don't need to compute the new dimension_number for the bprop dots of grad_x and grad_w.

And I agree the code is a bit hacky and may be that clear on its purpose. I will try your dot algorithm since it works now as you mentioned.

Will update the thread later. Thx.

@dfm
Copy link
Collaborator

dfm commented Oct 23, 2024

An update! With the release of JAX v0.4.35 using dot algorithms as an argument for lax.dot_general now works (#24480)! So perhaps we can try to fix #24047 using that?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants