[NVIDIA] Add new SDPA API to jax.nn #21371
Conversation
@hawkinsp Can you help find reviewers?
Pushed a new commit to remove the use of
@sharadmv Any updates?
The API should have an
Regarding the names: does cuDNN expose both FlashAttention and non-FlashAttention? Perhaps this should be
Are there any configuration options a user might want to pass to the cuDNN implementation? If so, it could be a string or a cuDNN config dataclass, e.g. in the low-memory XLA case, the chunk size is something a user might want to configure.
jax/_src/nn/functions.py
Outdated
warnings.warn(f"The flash attention cannot be used because: {e}")

# Compute the attention logits
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
In the transformer engine docs (https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.flax.DotProductAttention) for the JAX cuDNN fused attention API, it states "float32_logits (bool, default = False) – Whether to compute attention logits in float32 for the unfused attention backend. For fused attention backend, the accumulation is always float32 without the perf overhead." It would be good to match numerics as closely as possible. What exactly is cuDNN doing here? Is it accumulating in FP32 (say for bf16 inputs) and then keeping it in FP32 to compute softmax? Or does it cast the accumulation to BF16 like this does?
Yes, we accumulate in fp32 and do the softmax directly on it. We don't cast to BF16 and then cast back to FP32 for the softmax in flash attention.
OK, so you are using quite different numerics for BF16 in your XLA and cuDNN implementations. The XLA equivalent here is to use preferred_element_type=jnp.float32 in jnp.einsum. This has the negative consequence of using much more memory. But having inconsistent numerics might mean that if you train with cuDNN, inference will no longer work due to the downcasted numerics. Is there an option for controlling this in cuDNN, so we can use the XLA-default downcast behaviour by default and users can switch using preferred_element_type=jnp.float32?
cuDNN uses a different algorithm for the attention (i.e. the flash attention algorithm), so from that alone the numerics would be quite different from XLA. As for the dot compute/accumulation dtype, I don't think the current cuDNN allows a config for that.
@Cjkkkk Can you comment on this? Basically, I think Google's concern is how to best match the numerics of the other default implementations. Or do you know if we have tried things like training a model with one attention implementation but running inference with another?
FlashAttention is numerically identical to XLA attention, modulo doing ops in different orders, which is common in ML. Here we are discussing downcasting the logits to BF16 (your XLA implementation) versus not downcasting to BF16 (cuDNN). We should match numerics as much as possible so that people can easily mix and match implementations. There are two options:
- Make the cuDNN numerics the default, and match this in XLA (so use preferred_element_type=jnp.float32 in the einsum). This comes at the cost of larger default XLA memory usage and slower speed (more memory to move). Allow users to opt out of this with an option (can add this later).
- Make downcasting the default. We then need an option to use cuDNN attention.
I think I prefer the first option. Either way, we should document this.
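As a rough sketch (illustrative only, not this PR's code), the two options differ on the XLA side only in whether the logits einsum keeps float32:

import jax.numpy as jnp

def logits_downcast(query, key):
    # Default behaviour for bf16 inputs: the result of the QK product is in bf16.
    return jnp.einsum('BTNH,BSNH->BNTS', query, key)

def logits_fp32(query, key):
    # Option 1: keep the logits in float32, matching cuDNN's fp32 accumulation,
    # at the cost of a larger (B, N, T, S) logits tensor in memory.
    return jnp.einsum('BTNH,BSNH->BNTS', query, key,
                      preferred_element_type=jnp.float32)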
Ah, I see your point now. I think we did something similar in Transformer Engine, with an arg like https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/softmax.py#L274. So, we can add a similar one to this API and set it to True by default. Then it's like:
# xla path (softmax_in_fp32=True by default):
preferred_element_type = jnp.float32 if softmax_in_fp32 else None
# first gemm
logits = jnp.einsum(..., preferred_element_type=preferred_element_type)

# cudnn path: always accumulates and softmaxes in fp32, so reject the opt-out
if not softmax_in_fp32:
  raise ValueError("cudnn implementation always computes the softmax in fp32")
What do you think?
The softmax in flash attention (online softmax) is a fundamentally different algorithm; mathematically it should be equivalent to the JAX softmax, but due to floating-point computation it is not numerically the same.
cuDNN attention accumulates in fp32 and does not downcast to bf16 before the softmax, so the first option would align both attentions.
Added the flag softmax_in_fp32 and defaulted it to True.
jax/_src/nn/functions.py
Outdated
*,
scale: float | None = None,
logits_cap: float | None = None,
mask_fn: Callable[[ArrayLike, ArrayLike], ArrayLike] = None,
Do we have an application for the mask function beyond causal masking? This is such a common use case that I think it should have its own option (is_causal; PyTorch has this as well: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). It's ugly to have to pass mask_fn=jax.nn.SdpaCausalMask().
Yes, there are more masking types, as defined here. We can use a separate is_causal for the common case, but does that mean we need to provide knob options in the API for every new mask type?
Masking is one way to implement causal attention, but good implementations avoid implementing it in this way. As it is conceptually distinct from masking, this should be reflected in the API.
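For reference, a minimal sketch (illustrative, not this PR's code) of how a reference XLA path could realize is_causal via a mask; a fused kernel like cuDNN's instead skips the masked blocks rather than materializing a mask:

import jax.numpy as jnp

def apply_causal_mask(logits):
    # logits: (B, N, T, S); keep positions where the key index <= query index.
    T, S = logits.shape[-2], logits.shape[-1]
    causal = jnp.tril(jnp.ones((T, S), dtype=bool))
    return jnp.where(causal, logits, jnp.finfo(logits.dtype).min)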
For the other masking types: I notice they are not supported in your current API yet. I think we should discuss these in a separate PR. But we should get rid of mask_fn for now.
jax/_src/nn/functions.py
Outdated
mask = mask[jnp.newaxis, jnp.newaxis, :, :]
return mask

class SdpaCausalMask:
This name is non-ideal; better to avoid this confusing acronym (ScaledDotProductAttentionCausalMask is clearer).
Done.
jax/_src/nn/functions.py
Outdated
_pv_dot_general: Callable[..., ArrayLike] = lax.dot_general) -> Array:
r"""Scaled dot product attention function.

Computes the attention function on Quary, Key, and Value tensors:
Quary -> Query.
Done.
jax/_src/nn/functions.py
Outdated
if not (mask is None or mask_fn is None):
  raise ValueError("Cannot provide both mask and mask_fn")

# Try if the flash attention can be used. If failed, fall back to the default
Users of this function will very likely want to use an optimized FlashAttention kernel. Yet it's extremely opaque to the user exactly what they must do to actually use a fast kernel (what shapes? what API surface of this function are they allowed to use?).
Yes, that is a valid concern. The reason we use this try-except block is that the fast-attention support surface is too complex to show here; if you check here and search for "raise", the restrictions are scattered across shapes/sharding/CUDA/cuDNN versions/etc. We would like these restrictions to stay in that file because they are specific to cuDNN and they might keep changing. But to your question: we try to give users a readable error string explaining why flash attention was not used at runtime.
Currently, you are not throwing an error when FlashAttention is not used. I think it's easy to miss error messages that are printed/logged somewhere. This is related to the idea of having an implementation option, where users can explicitly ask for the cudnn implementation and fail if it's not supported (with good errors).
jax/_src/nn/functions.py
Outdated
  :code:`cap * jnp.tanh(logits / cap)`.
mask_fn: custom masking function takes in :code:`logits` and :code:`mask`
  and outputs masked :code:`logits`.
softmax_fn: custom softmax function takes in :code:`logits` and outputs
What is the motivation for this generality? Fast kernels will never be used when you have such custom functions, and this function provides little benefit over writing this yourself in JAX.
The motivation also comes from the praxis impl, like here. Also, I think the design rationale is to make the public API support surface as general as possible. And yes, the fast-attention support surface is just a subset of the API at this moment, but we will keep trying to expand it gradually.
I think we should be as conservative as possible with features for this JAX function. It's easy to add this if needed, but hard to take away. And as mentioned, it provides almost no benefit to users right now (it's very easy to implement an XLA attention with custom softmax if a user wants this). So let's remove everything not needed for hooking up cuDNN attention.
jax/_src/nn/functions.py
Outdated
value: value array; shape :code:`(BSNH)`
bias: optional, bias array to be added to logits; shape broadcastable to
  :code:`(BNTS)`.
mask: optional, mask array used to filter out logits. We use additive
This is confusing. Why not just have a boolean mask?
In attention, we typically use additive masks to make sure the undesired values are very small (negative inf) so that the following softmax generates 0s for them; see, for example, praxis here. Also, I just noticed that the PyTorch API supports both a bool mask (multiplicative) and a float mask (additive). So, do you want to support both?
Having a boolean mask allows for optimizations (less data to move around), and allows this function to handle the complexity of zeroing out correctly for all dtypes. So if we have an additive bias + boolean mask, that seems to cover all cases. So let's do that.
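A minimal sketch of that convention (additive float bias plus boolean mask, True meaning "keep"); names here are illustrative, not the PR's final code:

import jax.numpy as jnp

def combine_bias_and_mask(logits, bias=None, mask=None):
    # bias: additive float array; mask: boolean array, True = keep.
    # Both broadcastable to the (B, N, T, S) logits shape.
    if bias is not None:
        logits = logits + bias
    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return logits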
jax/_src/nn/functions.py
Outdated
@dataclass
class SdpaPhiloxDropout:
  rate: float = 0.0
  seed: int = 123
This obviously doesn't compose well with the JAX RNG system as it doesn't take a JAX random seed. Could it?
The jax random seed is also an int, like here. My understanding is that cudnn uses a different PRNG algorithm ("philox") whereas jax uses "threefry". So, if users choose cudnn flash attention + dropout, I would like them to explicitly use this xxxPhiloxDropout to be aware of the possible numeric difference. I don't think jax supports philox. But maybe I missed your point. Can you clarify how to "compose well with the JAX RNG system"?
I meant: JAX has an API for splitting PRNG keys (you don't work with raw integer seeds directly, except for seeding the very first jnp.PRNGKey). All JAX random functions take either an int or a jnp.PRNGKey object, so this should too, I think.
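For concreteness, the JAX RNG convention being referred to looks like this (generic sketch, unrelated to the cuDNN Philox path; shapes are illustrative):

import jax

root_key = jax.random.PRNGKey(0)             # the only place a raw int seed appears
dropout_key, next_key = jax.random.split(root_key)
# e.g. a dropout keep-mask drawn from the split key
keep = jax.random.bernoulli(dropout_key, p=0.9, shape=(4, 8, 128, 128))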
jax/_src/nn/functions.py
Outdated
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                    _dot_general=_qk_dot_general)
# Bias
if bias is not None:
I would separate out the XLA implementation into _dot_product_attention_xla or something, and call it in the user-facing function.
Do you mean put all the following code into a new function _dot_product_attention_xla?
Yes.
Done.
jax/_src/nn/functions.py
Outdated
@@ -31,6 +32,8 @@
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention, MaskType)
Import dot_product_attention as cudnn_dot_product_attention.
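Concretely, the suggested aliasing (module path taken from the diff above) would look like:

from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)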
Done.
tests/nn_test.py
Outdated
sdpa = nn.scaled_dot_product_attention
B, S, T, N, H = 4, 128, 128, 8, 32
keys = random.split(random.PRNGKey(0), 4)
Q = random.normal(keys[0], (B, T, N, H), dtype=jnp.bfloat16)
We need tests for more dtypes. For example for FP32: is cudnn flashattention using TF32 by default to match the JAX implementation? (not sure whether this is even supported)
cudnn flash attn only supports fp16 and bf16 (see here). So we focus on these two cases at this moment.
Sorry, I think I missed this comment. Do you mean something like:
def sdpa(..., implementation=None):
  if implementation == 'cudnn':
    cudnn_sdpa()  # users expect to fail on error
  elif implementation == 'pallas':
    pallas_sdpa()  # this is for the future
  elif implementation is None:
    # current path of try-except; will always fall back to `_dot_product_attention_xla`
Re cudnn flash attentions:
That looks correct. We have two options here:
There are pros and cons of each, and some tricky questions. For example:
I think the name should be
jax/_src/nn/functions.py
Outdated
T = query.shape[1]
scale_val = (1.0 / np.sqrt(H) if scale is None else scale)

if not (mask is None or mask_fn is None):
This further shows the limitations of the causal-attention-as-mask API: it prevents the user from using fast causal attention together with a custom array mask (e.g. to mask per-example padding).
For this padding + causal case, we would like users to use our MaskType.PADDING_CAUSAL mode, which needs two additional arguments there. So I am imagining that, like CausalMask, we also provide a PaddingCausalMask (not in this PR, but probably a follow-up PR):
class PaddingCausalMask:
  q_seqlen: Array = None
  kv_seqlen: Array = None
Then, we can extract these arrays and pass them to the cudnn_sdpa().
In other words, to use cudnn sdpa, we don't recommend mixing a static mask with a runtime mask, although we could do that. I think I added the quoted mask is None or mask_fn is None check from the pytorch impl to be more conservative. But if you prefer supporting such mixed usage, I can make a change.
I think query and key sequence lengths should be arguments to the main function. Again, padding here is an implementation detail. I assume cuDNN can get speedups from knowing the sequence lengths, by avoiding unnecessary computation?
I think query and key sequence lengths should be arguments to the main function. With this and the is_causal argument, you can call the appropriate cuDNN padding APIs.
Padding will be added in the next PRs.
jax/_src/nn/functions.py
Outdated
return x + mask

@dataclass
class ScaledDotProductAttentionPhiloxDropout:
What is the sharding behaviour of this dropout? I'm worried that it won't compose well with JAX sharding APIs.
To be explicit: consider the data-parallel case. Applying this function on a single device should give the same result as applying it sharded over two devices. This obviously causes issues for the RNG, as per device the random tensors now have different shapes. Which keys to use? JAX handles this, but it's a problem when using built-in custom-op RNG as in this case.
Dropout will be added in the next PRs, and we will specify the parallelism support for it there. Thanks for bringing up this concern.
As discussed offline: let's land the simplest version first, without dropout or other complications, then progressively add features.
Just pushed some new commits for the simplified sdpa. @sbodenstein PTAL. Also talked to @Cjkkkk and he will try to implement the combination of bias and mask in the cudnn dot_product_attention API (as described here in (1)). When that is in, our logic for preparing the bias will be much simpler.
jax/_src/nn/functions.py
Outdated
  only supports post-bias scale.
is_causal: If true, assumes upper left causal attention masking.
implementation: A string to control which implementation to use. Supported
  strings are `xla` (default), `cudnn` (cuDNN flash
'xla' is not the default; the default is None, which allows this function to choose the best implementation. Add documentation for None as well.
I removed the None for now. I plan to add None in the next PR. Basically, for "choose the best implementation", do you mean the try-except block like I had previously:
...
elif implementation is None:
  try:
    return cudnn_dot_product_attention()
  except Exception as e:
    warnings.warn(...)
    return _dot_product_attention_xla()
Yes, exactly. The assumption is that flashattention will be better whenever it is available. This seems like good default behaviour.
Sure. We are waiting for this PR to be merged, and then we won't need to duplicate the mask+bias preparation code.
Added None implementation but mainly as a placeholder.
jax/_src/nn/functions.py
Outdated
  (e.g. :code:`-0.5 * jnp.finfo(dtype).max`) represents `False`. The
  shape is broadcastable to :code:`(BNTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
  the square root of query size. Note, the currently implementation
Nit: query size is not a defined term.
Done.
jax/_src/nn/functions.py
Outdated
value: value array; shape :code:`(BSNH)`
bias: optional, bias array to be added to logits; shape broadcastable to
  :code:`(BNTS)`.
mask: optional, mask array used to filter out logits. Two types of masks are
what does the float mask buy us here? Do we need it?
Can't you do the same with the bias term, which pytorch doesn't have?
But also: what does this give you that you can't do with the boolean mask?
If there is any question: it's easier to add features than to remove them. We can always add this in a later PR.
Yep, I think it's a good point. We can let users pass anything that needs to be added to the logits via bias. So, the float mask should be passed into bias, or, if users have a relative position bias, they should do the addition relative_bias + float_mask and pass it to bias.
Then mask is only for the boolean mask, as you suggested. This simplifies our design. Do I have it right? I will update.
Done
jax/_src/nn/functions.py
Outdated
else:
  raise ValueError(f"Unsupported implementation option: {implementation}")

return encoded, probs
This does not match the type signature of the function.
Done.
jax/_src/nn/functions.py
Outdated
  attention).

Returns:
  An array with the same shape of :code:`query`. If flash attention is not
The output type signature should not depend on the implementation. Why do we need to return probs? If this becomes necessary, we should probably control it with an argument and fail if implementation='cudnn'.
Sure. The new change makes the API only return the attention output. Added a new return_probs parameter; it will return probs when it is True.
jax/_src/nn/functions.py
Outdated
  shape is broadcastable to :code:`(BNTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
  the square root of query size. Note, the currently implementation
  only supports post-bias scale.
I think we should use the flax convention and scale before bias. Note that users can always work around this if they need to.
Done. Now the behavior is scale*QK + bias + mask.
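That is, roughly the following (an illustrative sketch of the stated order, not the PR's exact code):

import jax.numpy as jnp

def scaled_logits(query, key, bias, scale):
    # Flax convention: scale the QK product first, then add bias (and mask terms).
    return scale * jnp.einsum('BTNH,BSNH->BNTS', query, key) + bias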
jax/_src/nn/functions.py
Outdated
scale: scale for the logits. If None, the scale will be set to 1 divided by
  the square root of query size. Note, the currently implementation
  only supports post-bias scale.
is_causal: If true, assumes upper left causal attention masking.
I would remove any mention of masking. It is also useful info that cudnn will avoid computing the masked regions, so this option will give speedups over using mask to achieve the same thing.
Sure. Masking is removed; the new comments focus on describing the causal behavior.
jax/_src/nn/functions.py
Outdated
scale: scale for the logits. If None, the scale will be set to 1 divided by
  the square root of query's head dimension (i.e. H).
is_causal: If true, causal attention will be applied. Note, some
  implemntations like `xla` will generate a mask tensor and apply
implemntations -> implementations
Done.
jax/_src/nn/functions.py
Outdated
scale: float | None = None,
is_causal: bool = False,
implementation: str = 'xla',
return_probs: bool = False) -> Union[Array, tuple[Array, Array], str]:
Can this return a string? Also, prefer Array | tuple[Array, Array].
Does the raise statement in the else count? I am not sure. But when searching the jax code base, it seems some places added this str when a raise exists.
No. That is not a return. We could add a Raises section in the docs (https://google.github.io/styleguide/pyguide.html#244-decision).
jax/_src/nn/functions.py
Outdated
  not return_probs
), "Implementation `cudnn` doesn't support return_probs=True."

if implementation == 'xla':
Use match/case here.
Done.
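For reference, the suggested dispatch shape looks roughly like this (helper names taken from earlier in this thread; a sketch, not the merged code):

def _dispatch(implementation, *args, **kwargs):
    match implementation:
        case 'xla':
            return _dot_product_attention_xla(*args, **kwargs)
        case 'cudnn':
            return cudnn_dot_product_attention(*args, **kwargs)
        case _:
            raise ValueError(f"Unsupported implementation option: {implementation}")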
Pushed a few more changes. PTAL. @sbodenstein
jax/_src/nn/functions.py
Outdated
scale: float | None = None,
is_causal: bool = False,
implementation: str | None = 'xla',
softmax_in_fp32: bool = True,
I think we always compute the softmax in fp32 in XLA for bf16 inputs. I prefer the TransformerEngine float32_logits option. One ugliness: this is not true for FP64 inputs. One option is to have a logits_dtype that defaults to None, which means FP32 for BF16/FP16 inputs and FP64 for FP64 inputs. What do you think?
Or we drop this option for now.
Dropped this option. We can add this choice when the fast implementation supports it.
Pushed new commits to resolve some failed Python lint tests. Btw, can we have the access to add
Please squash the commits and we can merge.
Done. PTAL. @superbobry
I still saw this lint error:
No worries, I'll resolve this internally.
This is necessary to avoid a circular dependency jax -> fused_attention_stablehlo -> experimental -> jax in #21371. PiperOrigin-RevId: 650183480
This is necessary to avoid a circular dependency jax -> fused_attention_stablehlo -> experimental -> jax in #21371. PiperOrigin-RevId: 650201550
Thanks for adding FA! Is there a timeline to add
Yes, this is on our radar to be implemented. Can we know what types of models you are working on that need dropout?
Attention dropout would help for almost all low-data training regimes. Detection Transformers are one well-known example. Torch supports FA dropout (possibly non-deterministic) in their functional API.
Attention plays a crucial role in modern transformer-based models. While there exist various variants, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to enhance the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup for training GPT-style large language models.
This PR proposes introducing a new API in the jax.nn module to handle attention. It will first try to use the cudnn flash attention execution path when the config is compatible; otherwise it falls back to a jax implementation.
cc. @nluehr @Cjkkkk @cliffwoolley
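A hypothetical usage sketch, assuming the function is exposed under the name used in this PR's tests (jax.nn.scaled_dot_product_attention) with the query/key/value shapes and is_causal/implementation parameters discussed above; the merged API may differ:

import jax
import jax.numpy as jnp

B, T, S, N, H = 4, 128, 128, 8, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, S, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, S, N, H), dtype=jnp.bfloat16)

# Explicitly request the cuDNN flash-attention path; expected to raise if the
# configuration is unsupported rather than silently falling back.
out = jax.nn.scaled_dot_product_attention(q, k, v, is_causal=True,
                                          implementation='cudnn')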