Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Support custom dtype convert in jax.nn.dot_product_attention #24352

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,11 +843,51 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
padded_logits = jnp.where(combined_mask, logits, large_negative_number)
return padded_logits

# We use this custom dot_general in the QK einsum op of attention to match
# dtypes used in the Flash Attention implementation. For bf16 inputs as an
# example, the fprop is like:
# bf16 -> dot -> fp32
# Then the bprop is like:
# (1) Without this change:
# fp32 -> dot -> fp32 -> cvt -> bf16.
# (2) With this change:
# fp32 -> cvt -> bf16 -> dot -> bf16.
@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_custom_convert(
lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
):
return lax.dot_general(
lhs, rhs, dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type,
)

@dot_general_with_custom_convert.defjvp
def dot_general_with_custom_convert_jvp(
dimension_numbers, precision, preferred_element_type, primals, tangents
):
lhs, rhs = primals
lhs_dot, rhs_dot = tangents

out = lax.dot_general(
lhs, rhs, dimension_numbers, precision=precision,
preferred_element_type=preferred_element_type,
)

grad_out = lax.dot_general(
lhs_dot, rhs, dimension_numbers, precision=precision,
) + lax.dot_general(
lhs, rhs_dot, dimension_numbers, precision=precision,
)
grad_out = grad_out.astype(preferred_element_type)
Comment on lines +876 to +881
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you talk through the argument here in a bit more detail? This doesn't seem like a great idea, because it's not clear to me why we would want to accumulate the tangents using a different dtype than the primals. I see why this ends up giving the correct dtypes on the backwards pass, but it seems bad for the numerics of fwd mode.

Perhaps there's something I'm missing here, but if not, we'd probably want to use custom_vjp instead of custom_jvp because then it's clear that the fwd pass isn't appropriate.

WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason we use jvp instead of vjp is that we don't need to compute the new dimension_number for the bprop dots of grad_x and grad_w.

And I agree the code is a bit hacky and may be that clear on its purpose. I will try your dot algorithm since it works now as you mentioned.

Will update the thread later. Thx.


return out, grad_out

def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
scale, q_seqlen, kv_seqlen, local_window_size):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
preferred_element_type=logits_dtype,
_dot_general=dot_general_with_custom_convert)

logits *= jnp.array(scale, dtype=logits.dtype)

Expand Down
28 changes: 22 additions & 6 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
from functools import partial
import itertools
import re
import unittest

from absl.testing import absltest
Expand Down Expand Up @@ -99,8 +100,8 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl):

self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01)
self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)

@parameterized.product(
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
Expand Down Expand Up @@ -164,10 +165,10 @@ def testDotProductAttentionMask(self, mask_mode):
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02)
self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)
self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)
self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02)

@parameterized.product(
batch_size=[1, 16],
Expand Down Expand Up @@ -224,7 +225,22 @@ def bwd_ans(x, bias, mask):
else:
_, dbias_ref, _ = bwd_ref(x, bias, mask)
_, dbias_ans, _ = bwd_ans(x, bias, mask)
self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03)
self.assertAllClose(dbias_ans, dbias_ref, rtol=.02, atol=.02)

def testDotProductAttentionCustomDtype(self):
dtype = jnp.bfloat16
B, S, N, H = 4, 128, 4, 32
keys = random.split(random.PRNGKey(0), 2)
x = random.normal(keys[0], (B, S, N, H), dtype)

def attention(x):
return jax.nn.dot_product_attention(x, x, x, implementation='xla')
_, f_vjp = jax.vjp(attention, x)
jitted = jax.jit(f_vjp)

hlo = jitted.lower(x).as_text("hlo")
dot_count = len(re.findall(r"dot.*? = bf16.*?", hlo))
self.assertEqual(4, dot_count)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
Expand Down