
[NVIDIA] Add new SDPA API to jax.nn #21371

Closed

Conversation

Contributor

@kaixih kaixih commented May 22, 2024

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow; examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations such as the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver roughly a 1.3x end-to-end speedup when training GPT-style large language models.

This PR proposes introducing a new API in the jax.nn module to handle attention. It will first try the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a pure-JAX (XLA) implementation.
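For illustration, calling the proposed API might look roughly like the sketch below (argument names such as is_causal and implementation follow the discussion later in this thread; the final signature may differ):

import jax
import jax.numpy as jnp

B, T, S, N, H = 4, 128, 128, 8, 32
ks = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(ks[0], (B, T, N, H), dtype=jnp.bfloat16)
key   = jax.random.normal(ks[1], (B, S, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(ks[2], (B, S, N, H), dtype=jnp.bfloat16)

# implementation=None lets the function pick the best available path
# (cuDNN flash attention when supported, otherwise the XLA fallback).
out = jax.nn.dot_product_attention(query, key, value, is_causal=True,
                                   implementation=None)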

cc. @nluehr @Cjkkkk @cliffwoolley

Contributor Author

kaixih commented May 22, 2024

@hawkinsp Can you help find reviewers?

@hawkinsp hawkinsp requested a review from sharadmv May 29, 2024 17:29
Contributor Author

kaixih commented May 31, 2024

Pushed a new commit to remove the use of is_training for the cuDNN flash attention. This is a follow-up to this merged PR.

Contributor Author

kaixih commented Jun 4, 2024

@sharadmv Any updates?

Contributor

sbodenstein commented Jun 7, 2024

The API should have an implementation option, taking values like "xla", "cudnn", and None (the default, which selects the best algorithm). This list will grow with alternative kernel implementations (Pallas, etc.). It is important to be able to select the implementation type:

  • "cudnn" will fail immediately on any unsupported shape, which prevents silent reversions to slow code paths.
  • It is needed when generating serialized models for inference on a different device type (e.g. train on GPU and run inference on TPU).

Regarding the names: does cuDNN expose both FlashAttention and non-FlashAttention? Perhaps this should be "cudnn_flash"? Note that XLA also has different implementations: we could support the low-memory chunked implementation given here (https://arxiv.org/abs/2112.05682) that inspired FlashAttention, and which is closer numerically to FlashAttention than standard attention and has the same memory complexity (maybe "xla_chunked"? "xla_low_memory"?).

Are there any configuration options a user might want to pass to the cuDNN implementation? If so, it could be a string or a cuDNN config dataclass. Eg. in the low-memory XLA case, the chunk size is something a user might want to configure.

warnings.warn(f"The flash attention cannot be used because: {e}")

# Compute the attention logits
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
Contributor

In the transformer engine docs (https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.flax.DotProductAttention) for the JAX cuDNN fused attention API, it states "float32_logits (bool, default = False) – Whether to compute attention logits in float32 for the unfused attention backend. For fused attention backend, the accumulation is always float32 without the perf overhead." It would be good to match numerics as closely as possible. What exactly is cuDNN doing here? Is it accumulating in FP32 (say for bf16 inputs) and then keeping it in FP32 to compute softmax? Or does it cast the accumulation to BF16 like this does?

Contributor Author

Yes, we accumulate in FP32 and apply the softmax directly to it. We don't cast down to BF16 and then cast back to FP32 for the softmax in flash attention.

Contributor

OK, so you are using quite different numerics for BF16 in your XLA and cuDNN implementations. The XLA equivalent here is to use preferred_element_type=jnp.float32 in jnp.einsum. This has the negative consequence of using much more memory. But having inconsistent numerics might mean that if you train with cuDNN, inference no longer works due to the downcasted numerics. Is there an option for controlling this in cuDNN, so that we get the XLA-default downcast behaviour by default and users can switch using preferred_element_type=jnp.float32?
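For illustration, a minimal sketch of the two BF16 numerics being compared here (function names are illustrative, not from the PR):

import jax.numpy as jnp

def logits_downcast(query, key):
    # XLA default: the matmul accumulates in FP32 internally, but the result
    # is written back in BF16, so the softmax sees BF16 logits.
    return jnp.einsum('BTNH,BSNH->BNTS', query, key)

def logits_fp32(query, key):
    # cuDNN-like numerics: keep the logits in FP32 all the way to the softmax,
    # at the cost of a larger (B, N, T, S) logits tensor in memory.
    return jnp.einsum('BTNH,BSNH->BNTS', query, key,
                      preferred_element_type=jnp.float32)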

Contributor Author

cuDNN uses a different algorithm for the attention (i.e. the flash attention algorithm), so from that alone the numerics would already differ from XLA. As for the dot compute/accumulation dtype, I don't think the current cuDNN allows configuring that.

@Cjkkkk Can you comment on this? Basically, I think Google's concern is how to best match the numerics of the other (default) implementations. Also, do you know if we have tried things like training a model with one attention implementation but running inference with another?

Contributor

@sbodenstein sbodenstein Jun 28, 2024

FlashAttention is numerically identical to XLA attention, modulo doing ops in different orders, which is common in ML. Here we are discussing downcasting the logits to BF16 (your XLA implementation) versus not downcasting to BF16 (cuDNN). We should match numerics as much as possible so that people can easily mix and match implementations. There are two options:

  • Make the cuDNN numerics the default, and match this in XLA (so use preferred_element_type=jnp.float32 in the einsum). Comes at the cost of larger default XLA memory usage and slower speed (more memory to move). Allow users to opt-out of this with an option (can add this later).
  • Make downcasting the default. Need option then to use cuDNN attention.

I think I prefer the first option. Either way, should document this.

Contributor Author

@kaixih kaixih Jun 28, 2024

Ah, I see your point now. I think we did something similar in Transformer Engine, with an arg like https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/softmax.py#L274. So we can add a similar one to the API and set it to True by default. Then it would look like:

xla:
# softmax_in_fp32=True (default)
preferred_element_type = jnp.float32 if softmax_in_fp32 else None
# first gemm
logits = jnp.einsum(..., preferred_element_type=preferred_element_type)

cudnn:
# cuDNN always accumulates and applies the softmax in fp32, so it cannot
# honor softmax_in_fp32=False
if not softmax_in_fp32:
  raise some error

What do you think?

Contributor

The softmax in flash attention (online softmax) is a fundamentally different algorithm; mathematically it should be equivalent to the JAX softmax, but due to floating-point computation it is not numerically identical.

Contributor

cuDNN attention does accumulation in FP32 with no downcasting to BF16 before the softmax, so the first option would align both attentions.

Contributor Author

Added the flag softmax_in_fp32 and defaulted it to True.

*,
scale: float | None = None,
logits_cap: float | None = None,
mask_fn: Callable[[ArrayLike, ArrayLike], ArrayLike] = None,
Contributor

Do we have an application for the mask function beyond causal masking? Causal masking is such a common use case that I think it should have its own option (is_causal; PyTorch has this as well: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). It's ugly to have to pass mask_fn=jax.nn.SdpaCausalMask().

Contributor Author

Yes, there are more masking types, as defined here. We could use a separate is_causal for the common case, but does that mean we need to add a new knob to the API for every new mask type?

Contributor

@sbodenstein sbodenstein Jun 14, 2024

Masking is one way to implement causal attention, but good implementations avoid implementing it in this way. As it is conceptually distinct from masking, this should be reflected in the API.

Contributor

For the other masking types: I notice they are not supported in your current API yet. I think we should discuss these in a separate PR. But we should get rid of mask_fn for now.

mask = mask[jnp.newaxis, jnp.newaxis, :, :]
return mask

class SdpaCausalMask:
Contributor

This name is non-ideal; better to avoid the confusing acronym (ScaledDotProductAttentionCausalMask is clearer).

Contributor Author

Done.

_pv_dot_general: Callable[..., ArrayLike] = lax.dot_general) -> Array:
r"""Scaled dot product attention function.

Computes the attention function on Quary, Key, and Value tensors:
Contributor

Quary -> Query.

Contributor Author

Done.

if not (mask is None or mask_fn is None):
raise ValueError("Cannot provide both mask and mask_fn")

# Try if the flash attention can be used. If failed, fall back to the default
Contributor

Users of this function will very likely want to use an optimized FlashAttention kernel. Yet it's extremely opaque to the user what exactly they must do to actually get a fast kernel (which shapes? which parts of this function's API surface are they allowed to use?).

Contributor Author

Yes, that is a valid concern. The reason we use this try-except block is that the fast-attention support surface is too complex to spell out here; if you check that file and search for "raise", you'll see the restrictions on shapes/sharding/CUDA/cuDNN versions scattered throughout. We would like those restrictions to stay in that file because they are specific to cuDNN and may keep changing. But to your question: we try to give users a readable error string explaining why flash attention was not used at runtime.

Contributor

Currently, you are not throwing an error when flash attention is not used. I think it's easy to miss error messages that are only printed/logged somewhere. This is related to the idea of having an implementation option, where users can explicitly ask for the cudnn implementation and fail (with good errors) if it's not supported.

:code:`cap * jnp.tanh(logits / cap)`.
mask_fn: custom masking function takes in :code:`logits` and :code:`mask`
and outputs masked :code:`logits`.
softmax_fn: custom softmax function takes in :code:`logits` and outputs
Contributor

What is the motivation for this generality? Fast kernels will never be used when you have such custom functions, and this function provides little benefit over writing this yourself in JAX.

Contributor Author

The motivation also comes from the Praxis implementation (linked here). The design rationale is to make the public API surface as general as possible. And yes, the fast-attention support surface is only a subset of the API at the moment, but we keep expanding it gradually.

Contributor

I think we should be as conservative as possible with features for this JAX function. It's easy to add this later if needed, but hard to take away. And as mentioned, it provides almost no benefit to users right now (it's very easy to implement an XLA attention with a custom softmax if a user wants one). So let's remove everything not needed for hooking up cuDNN attention.

value: value array; shape :code:`(BSNH)`
bias: optional, bias array to be added to logits; shape broadcastable to
:code:`(BNTS)`.
mask: optional, mask array used to filter out logits. We use additive
Contributor

This is confusing. Why not just have a boolean mask?

Contributor Author

In attention, we typically use additive masks so that the undesired positions get a very small value (negative infinity) and the following softmax produces zeros for them; Praxis does this, for example (here). Also, I just noticed that the PyTorch API supports both a bool mask (multiplicative) and a float mask (additive). So, do you want to support both?

Contributor

Having a boolean mask allows for optimizations (less data to move around) and lets this function handle the complexity of zeroing out correctly for all dtypes. So an additive bias plus a boolean mask seems to cover all the cases. Let's do that.
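For illustration, a minimal sketch of the two mask conventions being discussed (shapes and names are illustrative, not code from this PR):

import jax.numpy as jnp

T = S = 8
bool_mask = jnp.tril(jnp.ones((T, S), dtype=bool))   # True = attend (causal)

# A float/additive mask is just the boolean mask mapped to {0, large negative}:
additive_mask = jnp.where(bool_mask, 0.0, jnp.finfo(jnp.float32).min)

# So a user holding an additive mask (or a relative-position bias) can fold it
# into `bias`, and the API itself only needs to accept the boolean `mask`.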

@dataclass
class SdpaPhiloxDropout:
rate: float = 0.0
seed: int = 123
Contributor

This obviously doesn't compose well with the JAX RNG system as it doesn't take a JAX random seed. Could it?

Contributor Author

The JAX random seed is also an int (see here). My understanding is that cuDNN uses a different PRNG algorithm ("philox") while JAX uses "threefry". So if users choose cuDNN flash attention + dropout, I would like them to explicitly use this xxxPhiloxDropout so that they are aware of the possible numeric difference. I don't think JAX supports Philox. But maybe I missed your point; can you clarify how to "compose well with the JAX RNG system"?

Contributor

I meant: JAX has an API for splitting PRNG keys (you don't work with raw integer seeds directly except to seed the very first PRNG key). All JAX random functions take a PRNG key object, so this should too, I think.
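For context, a minimal sketch of the JAX RNG idiom being referred to (illustrative only, not code from this PR):

import jax

key = jax.random.PRNGKey(0)                 # seed once with an int
key, dropout_key = jax.random.split(key)    # derive an independent key per use

# A dropout argument that accepts `dropout_key` (rather than a raw int seed)
# composes with this splitting discipline; (B, N, T, S) is the probs shape.
keep = jax.random.bernoulli(dropout_key, p=0.9, shape=(4, 8, 128, 128))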

logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
_dot_general=_qk_dot_general)
# Bias
if bias is not None:
Contributor

I would separate out the xla implementation into _dot_product_attention_xla or something and call it in the user-facing function.

Contributor Author

Do you mean put all the following code into a new function _dot_product_attention_xla?

Contributor

Yes.

Contributor Author

Done.

@@ -31,6 +32,8 @@
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention, MaskType)
Contributor

Import dot_product_attention as cudnn_dot_product_attention.

Contributor Author

Done.

tests/nn_test.py Outdated
sdpa = nn.scaled_dot_product_attention
B, S, T, N, H = 4, 128, 128, 8, 32
keys = random.split(random.PRNGKey(0), 4)
Q = random.normal(keys[0], (B, T, N, H), dtype=jnp.bfloat16)
Contributor

@sbodenstein sbodenstein Jun 7, 2024

We need tests for more dtypes. For example, for FP32: is cuDNN flash attention using TF32 by default to match the JAX implementation? (Not sure whether this is even supported.)

Contributor Author

cuDNN flash attention only supports FP16 and BF16 (see here), so we focus on these two cases for now.

Contributor Author

kaixih commented Jun 12, 2024

The API should have an implementation option, taking values like "xla", "cudnn", and None (the default, which selects the best algorithm). This list will grow with alternative kernel implementations (Pallas, etc.). It is important to be able to select the implementation type:

  • "cudnn" will fail immediately on any unsupported shape, which prevents silent reversions to slow code paths.
  • It is needed when generating serialized models for inference on a different device type (e.g. train on GPU and run inference on TPU).

Regarding the names: does cuDNN expose both FlashAttention and non-FlashAttention? Perhaps this should be "cudnn_flash"? Note that XLA also has different implementations: we could support the low-memory chunked implementation given here (https://arxiv.org/abs/2112.05682) that inspired FlashAttention, and which is closer numerically to FlashAttention than standard attention and has the same memory complexity (maybe "xla_chunked"? "xla_low_memory"?).

Are there any configuration options a user might want to pass to the cuDNN implementation? If so, it could be a string or a cuDNN config dataclass. Eg. in the low-memory XLA case, the chunk size is something a user might want to configure.

Sorry, I think I missed this comment. Do you mean something like:

def sdpa(..., implementation=None):
  if implementation == 'cudnn':
    cudnn_sdpa()    # users expect this to fail on error
  elif implementation == 'pallas':
    pallas_sdpa()   # for the future
  elif implementation is None:
    # current try-except path, which falls back to `_dot_product_attention_xla`
    # whenever the cuDNN path cannot be used
    ...

Re cuDNN flash attention:
(1) cuDNN used to expose both flash and non-flash attention kernels, but we have chosen not to use the non-flash one anymore, so the cudnn attention now means cuDNN flash attention. I am OK with the name "cudnn".
(2) We don't need to pass any config to the cuDNN calls, and we are trying to hide it from users.

Contributor

sbodenstein commented Jun 13, 2024

Sorry, I think I missed this comment. Do you mean something like:

That looks correct. We have two options here:

  1. Have multiple SDPA functions, one per backend/implementation.
  2. Have a single API with the implementation option.

There are pros and cons of each, and some tricky questions. For example:

  • How closely do numerics need to match in the super-function to be considered 'the same'? As found in this review, cuDNN with bf16 inputs does not cast the first matmul to BF16 before doing softmax, whilst XLA does. If we choose the cuDNN convention, the XLA implementation will be incredibly memory-inefficient. This might be a significant difference in certain applications (eg. training with one but doing inference with the other on a different device-type). With future Pallas kernels, we can match the numerics. But this might be harder for third-party libraries like cuDNN. We might also do autotuning and choose the best kernel with the None option, which becomes problematic with these numerical differences. This is an argument to have separate functions for third-party kernels that JAX has no control over and are largely opaque (hard to see what numerical choices are being made), and only have a super-function for implementations under JAX-control.
  • Another argument for separate functions is that the API can then be restricted to only the supported features, rather than being the most general function imaginable. The current design makes it hard for users to see what is supported, and limits documentation opportunities. In addition, there are cuDNN-specific options (like the Philox dropout) unsupported by any other backend, further complicating the API.

@sbodenstein
Contributor

I think the name should be dot_product_attention rather than scaled_dot_product_attention. It's also more consistent with Flax naming (https://flax.readthedocs.io/en/v0.8.0/api_reference/flax.linen/_autosummary/flax.linen.dot_product_attention.html).

T = query.shape[1]
scale_val = (1.0 / np.sqrt(H) if scale is None else scale)

if not (mask is None or mask_fn is None):
Contributor

@sbodenstein sbodenstein Jun 14, 2024

This further shows the limitations of the causal-attention-as-mask API: it prevents the user from combining fast causal attention with a custom array mask (e.g. to mask per-example padding).

Contributor Author

For padding + causal, we would like users to use our MaskType.PADDING_CAUSAL mode, which needs two additional arguments (see there).

So, like the CausalMask, I imagine we would also provide a PaddingCausalMask (not in this PR, but probably in a follow-up PR):

class PaddingCausalMask:
  q_seqlen: Array = None
  kv_seqlen: Array = None

Then we can extract these arrays and pass them to cudnn_sdpa().

In other words, to use the cuDNN SDPA we don't recommend mixing a static mask with a runtime mask, although we could support that. I added the quoted mask is None or mask_fn is None check following the PyTorch implementation, to be conservative. But if you prefer supporting such mixed usage, I can make a change.

Contributor

I think query and key sequence lengths should be arguments to the main function. Again, padding here is an implementation detail. I assume cuDNN can get speedups from knowing the sequence lengths, by avoiding unnecessary computation?

Contributor

@sbodenstein sbodenstein Jun 17, 2024

I think query and key sequence lengths should be arguments to the main function

With this and the is_causal argument, you can call the appropriate cuDNN padding APIs.

Contributor Author

Padding will be added in the next PRs.

return x + mask

@dataclass
class ScaledDotProductAttentionPhiloxDropout:
Contributor

What is the sharding behaviour of this dropout? I'm worried that it won't compose well with JAX sharding APIs.

Contributor

To be explicit: consider the data-parallel case. Applying this function on a single device should give the same result as applying it sharded over two. This obviously causes issues for the RNG, as, per device, the random tensors now have different shapes. Which keys should each device use? JAX handles this, but it's a problem when using a built-in custom-op RNG like in this case.

Contributor Author

Dropout will be added in follow-up PRs, and we will specify its parallelism support there. Thanks for bringing up this concern.

@sbodenstein
Contributor

As discussed offline: let's land the simplest version first, without dropout or other complications. Then progressively add features.

Contributor Author

kaixih commented Jun 25, 2024

Just pushed some new commits for the simplified sdpa. @sbodenstein PTAL.

Also talked to @Cjkkkk, and he will try to implement the combination of bias and mask in the cudnn dot_product_attention API (as described here in (1)). When that is in, our bias-preparation logic will be much simpler.

only supports post-bias scale.
is_causal: If true, assumes upper left causal attention masking.
implementation: A string to control which implementation to use. Supported
strings are `xla` (default), `cudnn` (cuDNN flash
Contributor

'xla' is not the default; the default is None, which allows this function to choose the best implementation. Add documentation for None as well.

Contributor Author

I removed the None for now; I plan to add None in the next PR. Basically, for "choose the best implementation", do you mean the try-except block like I did previously:

...
elif implementation is None:
  try:
    return cudnn_dot_product_attention(...)
  except Exception as e:
    warnings.warn(f"The flash attention cannot be used because: {e}")
  return _dot_product_attention_xla(...)

Contributor

Yes, exactly. The assumption is that flash attention will be better whenever it is available. This seems like good default behaviour.

Contributor Author

Sure. We are waiting for this PR to be merged so that we don't need to duplicate the mask+bias preparation code.

Contributor Author

Added None implementation but mainly as a placeholder.

(e.g. :code:`-0.5 * jnp.finfo(dtype).max`) represents `False`. The
shape is broadcastable to :code:`(BNTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query size. Note, the currently implementation
Contributor

Nit: query size is not a defined term.

Contributor Author

Done.

value: value array; shape :code:`(BSNH)`
bias: optional, bias array to be added to logits; shape broadcastable to
:code:`(BNTS)`.
mask: optional, mask array used to filter out logits. Two types of masks are
Contributor

what does the float mask buy us here? Do we need it?

Contributor Author

I think it is general practice to use float masks for additive masking. In Praxis, for example, the mask already uses floats with large negative values (here). The PyTorch SDPA (here) also supports both. So I think maybe we should do the same.

Contributor

Can't you do the same with the bias term, which pytorch doesn't have?

Contributor

But also: what does this give you that you can't do with the boolean mask?

Contributor

If there is any question: it's easier to add features than to remove them. We can always add this in a later PR.

Contributor Author

Yep, I think it's a good point. We can let users pass anything that needs to be added to the logits via bias. So a float mask should be passed in as bias, or, if users also have a relative position bias, they should compute relative_bias + float_mask and pass the sum as bias.

Then mask is only for the boolean mask, as you suggested. This simplifies our design. Did I get that right? I will update.

Contributor Author

Done

else:
raise ValueError(f"Unsupported implementation option: {implementation}")

return encoded, probs
Contributor

This does not match the type signature of the function.

Contributor Author

Done.

attention).

Returns:
An array with the same shape of :code:`query`. If flash attention is not
Contributor

The output type signature should not depend on the implementation. Why do we need to return probs? If this becomes necessary, we should probably control it with an argument and fail if implementation='cudnn'.

Contributor Author

Sure. The new change makes the API return only the attention output. I added a new return_probs parameter; the probs are returned when it is True.

shape is broadcastable to :code:`(BNTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query size. Note, the currently implementation
only supports post-bias scale.
Contributor

I think we should use the flax convention, and scale before bias. Note users can always work around this if they need to.

Contributor Author

Done. Now the behavior is scale*QK+bias+mask.
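For clarity, a minimal sketch of the agreed ordering, scale * (Q K^T) first, then bias, then the boolean mask (names are illustrative, not code from this PR):

import jax.numpy as jnp

def attention_logits(query, key, bias=None, mask=None, scale=1.0):
    logits = scale * jnp.einsum('BTNH,BSNH->BNTS', query, key)
    if bias is not None:
        logits = logits + bias
    if mask is not None:  # boolean: True = attend, False = mask out
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return logits  # softmax over the S axis follows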

scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query size. Note, the currently implementation
only supports post-bias scale.
is_causal: If true, assumes upper left causal attention masking.
Contributor

I would remove any mention of masking. It is also useful to note that cuDNN will avoid computing the masked-out regions, so this option will give speedups over using mask to achieve the same thing.

Contributor Author

Sure. The mention of masking is removed; the new comment focuses on describing the causal behavior.

scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query's head dimension (i.e. H).
is_causal: If true, causal attention will be applied. Note, some
implemntations like `xla` will generate a mask tensor and apply
Contributor

implemntations -> implementations

Contributor Author

Done.

scale: float | None = None,
is_causal: bool = False,
implementation: str = 'xla',
return_probs: bool = False) -> Union[Array, tuple[Array, Array], str]:
Contributor

Can this return a string? Also prefer Array | tuple[Array, Array]

Contributor Author

Does the raise statement in the else branch count? I am not sure, but searching the JAX code base, it seems some places add str to the return type when a raise exists.

Contributor

No. That is not a return. We could add a Raises section in the docs (https://google.github.io/styleguide/pyguide.html#244-decision).

not return_probs
), "Implementation `cudnn` doesn't support return_probs=True."

if implementation == 'xla':
Contributor

Use match/case here.

Contributor Author

Done.
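For reference, a minimal sketch of the match/case dispatch being suggested; the helper names are the ones discussed above and their bodies are elided:

def dot_product_attention(query, key, value, *, implementation='xla'):
    match implementation:
        case 'xla':
            return _dot_product_attention_xla(query, key, value)
        case 'cudnn':
            return cudnn_dot_product_attention(query, key, value)
        case _:
            raise ValueError(f"Unsupported implementation option: {implementation}")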

Contributor Author

kaixih commented Jun 28, 2024

Pushed a few more changes. PTAL. @sbodenstein

scale: float | None = None,
is_causal: bool = False,
implementation: str | None = 'xla',
softmax_in_fp32: bool = True,
Contributor

I think we always compute the softmax in FP32 in XLA for BF16 inputs. I prefer the TransformerEngine float32_logits option. One ugliness: this is not true for FP64 inputs. One option is to have a logits_dtype that defaults to None, which means FP32 for BF16/FP16 inputs and FP64 for FP64 inputs. What do you think?

Contributor

Or we drop this option for now.

Contributor Author

Dropped this option. We can add this choice when the fast implementation supports it.

@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Jul 5, 2024
Contributor Author

kaixih commented Jul 5, 2024

Pushed new commits to resolve some failed Python lint tests. Btw, can we have access to add the kokoro:force-run label to trigger the tests?

@superbobry
Collaborator

Please squash the commits and we can merge.

Contributor Author

kaixih commented Jul 7, 2024

Done. PTAL. @superbobry

Contributor Author

kaixih commented Jul 7, 2024

I still see this lint error: jax/_src/nn/functions.py:924: error: Argument 4 to "dot_product_attention" has incompatible type "Array | ndarray[Any, Any] | bool_ | number[Any] | int | float | complex | None"; expected "Array | None" [arg-type]. But I am a bit confused: I think it refers to the mask, which I have already converted to an Array with jnp.asarray(mask) at the beginning of the function. Do you have any advice on this? @superbobry @sbodenstein

@superbobry
Collaborator

No worries, I'll resolve this internally.

copybara-service bot pushed a commit that referenced this pull request Jul 8, 2024
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in #21371.

PiperOrigin-RevId: 650183480
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Jul 8, 2024
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in jax-ml/jax#21371.

PiperOrigin-RevId: 650183480
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 8, 2024
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in jax-ml/jax#21371.

PiperOrigin-RevId: 650183480
copybara-service bot pushed a commit that referenced this pull request Jul 8, 2024
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in #21371.

PiperOrigin-RevId: 650201550
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Jul 8, 2024
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in jax-ml/jax#21371.

PiperOrigin-RevId: 650201550
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 8, 2024
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in jax-ml/jax#21371.

PiperOrigin-RevId: 650201550
@copybara-service copybara-service bot closed this in df6080f Jul 8, 2024
@MasterSkepticista

As discussed offline: let's land the simplest version first, without dropout or other complications. Then progressively add features.

Thanks for adding FA! Is there a timeline to add dropout support in the SDPA API? I understand it is on hold due to differences in PRNG implementation. Would it be OK if we expose dropout_rate to the API while warning the user on reproducibility if cudnn is selected?

https://github.com/google/jax/blob/417fcd574b9f33410ea8eb78ffdea825ad343eee/jax/_src/cudnn/fused_attention_stablehlo.py#L954-L956

Contributor Author

kaixih commented Aug 27, 2024

As discussed offline: let's land the simplest version first, without dropout or other complications. Then progressively add features.

Thanks for adding FA! Is there a timeline to add dropout support in the SDPA API? I understand it is on hold due to differences in PRNG implementation. Would it be OK if we expose dropout_rate to the API while warning the user on reproducibility if cudnn is selected?

https://github.com/google/jax/blob/417fcd574b9f33410ea8eb78ffdea825ad343eee/jax/_src/cudnn/fused_attention_stablehlo.py#L954-L956

Yes, this is on our radar. May we ask what types of models you are working on that need the dropout?

@MasterSkepticista

Yes, this is on our radar to be implemented. Can we know what types of model you are working on that needs the dropout?

Attention dropout would help for almost all low-data training regimes. Detection Transformers are one well-known example.

PyTorch supports FA dropout (possibly non-deterministic) in its functional API.
