Skip to content

Commit

Permalink
Remove is_training
Browse files Browse the repository at this point in the history
  • Loading branch information
kaixih committed May 31, 2024
1 parent a62f183 commit 7da7d9c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
5 changes: 2 additions & 3 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,11 +918,10 @@ def _assert_has_shape(t: ArrayLike, shape: Sequence[int]) -> None:
elif mask is not None:
bias = mask

# TODO(kaixih@nvidia): We set is_training to True for now. This argument
# will be removed later and then we should remove it as well.
encoded = dot_product_attention(
query, key, value, bias, None, scale=scale_val, mask_type=mask_type,
seed=seed, dropout_rate=rate, is_training=True)
seed=seed, dropout_rate=rate
)
return encoded, None
else:
warnings.warn("The flash attention cannot be used because unsupported"
Expand Down
17 changes: 12 additions & 5 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def testSdpa(self, use_bias):
bias = None

# Use a dummy lambda for mask_fn to disable flash attention.
out_ref, _ = sdpa(Q, K, V, bias, mask_fn=lambda x, mask: x)
out, _ = sdpa(Q, K, V, bias)
out_ref, probs_ref = sdpa(Q, K, V, bias, mask_fn=lambda x, mask: x)
out, probs = sdpa(Q, K, V, bias)
self.assertIsNotNone(probs_ref)
self.assertIsNone(probs)
self.assertAllClose(out_ref, out)

@parameterized.parameters(False, True)
Expand Down Expand Up @@ -115,13 +117,18 @@ def testSdpaMask(self, use_bias):
# For the reference, use the causal mask explicitly and a dummy lambda for
# dropout_fn to disable flash attention.
causal_mask = _causal_mask(T, jnp.bfloat16)
out_ref, _ = sdpa(Q, K, V, bias, causal_mask, dropout_fn=lambda x: x)
out_ref, probs_ref = sdpa(
Q, K, V, bias, causal_mask, dropout_fn=lambda x: x
)
self.assertIsNotNone(probs_ref)

# Test runtime generated causal mask
out, _ = sdpa(Q, K, V, bias, mask_fn=nn.SdpaCausalMask())
out, probs = sdpa(Q, K, V, bias, mask_fn=nn.SdpaCausalMask())
self.assertIsNone(probs)
self.assertAllClose(out_ref, out, atol=atol)
# Test user-given causal mask
out, _ = sdpa(Q, K, V, bias, causal_mask)
out, probs = sdpa(Q, K, V, bias, causal_mask)
self.assertIsNone(probs)
self.assertAllClose(out_ref, out, atol=atol)

@parameterized.parameters(False, True)
Expand Down

0 comments on commit 7da7d9c

Please sign in to comment.