[NVIDIA] Add new SDPA API to jax.nn #21371
@@ -16,10 +16,11 @@

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import operator
import numpy as np
from typing import Any
from typing import Any, Literal
import warnings

import jax

@@ -31,6 +32,8 @@
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike
from jax._src.ops.special import logsumexp as _logsumexp

@@ -765,3 +768,159 @@ def hard_silu(x: ArrayLike) -> Array:
  return x_arr * hard_sigmoid(x_arr)

hard_swish = hard_silu

def _get_large_negative(dtype):
  dtype_max = jnp.finfo(dtype).max
  return jnp.asarray(-0.7 * dtype_max, dtype=dtype)

def _get_causal_mask(T, S, dtype):
  pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
  mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
  return mask[jnp.newaxis, jnp.newaxis, :, :]

def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
  logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
  logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                      preferred_element_type=logits_dtype)

  logits *= jnp.array(scale, dtype=logits.dtype)

  if bias is not None:
    logits = (logits + bias).astype(logits.dtype)

  if mask is not None:
    assert mask.dtype == jnp.bool_
    large_negative_number = _get_large_negative(logits.dtype)
    padded_logits = jnp.where(mask, logits, large_negative_number)
  else:
    padded_logits = logits

  if is_causal:
    T, S = query.shape[-3], key.shape[-3]
    mask = _get_causal_mask(T, S, logits.dtype)
    padded_logits = padded_logits + mask

  # The softmax is always carried out in fp32.
  padded_logits = padded_logits.astype(jnp.float32)
  probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)

  encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
  return encoded

def dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    *,
    bias: ArrayLike | None = None,
Reviewer: Let's make `bias` and `mask` mandatory keyword args.
Author: I only made `mask` a mandatory keyword arg, because if `bias` is a mandatory keyword arg, `vjp(sdpa, q, k, v, bias=bias)` complains; I tried moving `bias=bias` into `partial()`, but that affects how I get `dbias`.
Reviewer: What complaint do you get? You can always use a lambda in vjp tests to turn it into positional args for testing.
Author: Yes, thanks for the advice. Done.
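The lambda trick mentioned above works because `jax.vjp` only accepts positional primals. A minimal sketch of the pattern, using a stand-in attention function rather than the PR's actual one (the names `sdpa_like`, `q`, `k`, `v` and the shapes are illustrative):

import jax
import jax.numpy as jnp

def sdpa_like(q, k, v, *, bias=None):
  # Stand-in for an attention function with a keyword-only `bias`.
  logits = jnp.einsum('BTNH,BSNH->BNTS', q, k)
  if bias is not None:
    logits = logits + bias
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('BNTS,BSNH->BTNH', probs, v)

q = k = v = jnp.ones((2, 8, 4, 16))
bias = jnp.zeros((2, 4, 8, 8))

# jax.vjp(sdpa_like, q, k, v, bias=bias) raises a TypeError; wrapping the call
# in a lambda exposes `bias` positionally, so vjp also returns dbias.
out, vjp_fn = jax.vjp(lambda q, k, v, b: sdpa_like(q, k, v, bias=b),
                      q, k, v, bias)
dq, dk, dv, dbias = vjp_fn(jnp.ones_like(out))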
    mask: ArrayLike | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
Reviewer: Do you think we can go with a shorter parameter name here, e.g. …?
Reviewer: We use 'backend' in JAX for GPU/CPU/TPU. I do think `implementation` is reasonable. And shouldn't abbreviations be avoided?
Author: Oh, just noticed the new comment. Changed it back to `implementation`.
  r"""Scaled dot product attention function.

  Computes the attention function on Query, Key, and Value tensors:

  .. math::
    \mathrm{Attention}(Q, K, V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

  Throughout this function, :code:`logits` refers to the output of :math:`QK^T`
  and :code:`probs` refers to the output of the :math:`\mathrm{softmax}`.

  The following uppercase letters are used to denote the array shapes:

    B = batch size
    S = length of the key/value (source)
    T = length of the query (target)
    N = number of attention heads
    H = dimensions of each attention head

  Args:
    query: query array; shape :code:`(BTNH)`
    key: key array; shape :code:`(BSNH)`
    value: value array; shape :code:`(BSNH)`
    bias: optional, bias array to be added to the logits; shape broadcastable
      to :code:`(BNTS)`.
    mask: optional, mask array used to filter out logits. It is a boolean mask
      where `True` indicates the element should take part in attention. For an
      additive mask, users should pass it to `bias`. The shape is broadcastable
      to :code:`(BNTS)`.
    scale: scale for the logits. If None, the scale will be set to 1 divided by
      the square root of the query's head dimension (i.e. H).
    is_causal: If true, causal attention will be applied. Note, some
      implementations like `xla` will generate a mask tensor and apply it to
      the logits, while other implementations like `cudnn` will avoid computing
      the masked-out regions altogether.
    implementation: A string to control which implementation backend to use.
      Supported strings are `xla` and `cudnn` (cuDNN flash attention). It
      defaults to `None`, which will automatically select the best available
      backend. Note, `cudnn` supports only a subset of shapes/dtypes, and an
      exception will be thrown if it is not supported.

  Returns:
    An array of the attention output with the same shape as :code:`query`.
  """
  def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
    if t.ndim != len(shape):
      raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
    for i in range(t.ndim):
      if shape[i] != -1 and t.shape[i] != shape[i]:
        raise ValueError(f"{name} shape should be {shape}, but got {t.shape}")

  query = jnp.asarray(query)
  key = jnp.asarray(key)
  value = jnp.asarray(value)
  bias = bias if bias is None else jnp.asarray(bias)
  mask = mask if mask is None else jnp.asarray(mask)

  B, S, N, H = key.shape
Reviewer: You need to call …
Author: Done.
  _check_has_shape(value, [B, S, N, H], 'value')
  _check_has_shape(query, [B, -1, N, H], 'query')
  T = query.shape[1]
  scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
  if not (query.dtype == key.dtype == value.dtype):
    raise ValueError(f"query/key/value should have the same dtype, but got "
                     f"{query.dtype} vs {key.dtype} vs {value.dtype}.")
  if mask is not None and mask.dtype != jnp.bool_:
    raise ValueError(f"Mask must be boolean dtype, but got {mask.dtype}.")

  match implementation:
    case 'xla':
      return _dot_product_attention_xla(
          query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
      )
    case 'cudnn':
      mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
      # Convert the boolean mask to a float mask for addition.
      if mask is not None:
        large_negative_number = _get_large_negative(query.dtype)
        mask = jnp.where(mask, jnp.zeros((), query.dtype),
                         large_negative_number)

      # Prepare the bias for cudnn flash attention:
      # We should never use the mask argument of cudnn, because it is
      # multiplicative and thus the masked values (i.e. the zeros) would
      # still take part in the following softmax. So, we need to fold the mask
      # into the bias argument to ensure the masked values are very small.
      # TODO(kaixih@nvidia): This logic should be moved inside
      # cudnn_dot_product_attention.
      if bias is None:
        bias = mask
      elif mask is not None:
        bias = bias + mask

      return cudnn_dot_product_attention(
          query, key, value, bias, mask=None, scale=scale_val,
          mask_type=mask_type,
      )
    case None:
      # TODO(kaixih@nvidia): Defaults to XLA for now. Will automatically select
      # the best backend later.
      return _dot_product_attention_xla(
          query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
      )
    case _:
      raise ValueError(f"Unsupported implementation option: {implementation}")
@@ -28,6 +28,8 @@
from jax._src import core
from jax._src import test_util as jtu
from jax._src import ad_checkpoint
from jax._src.interpreters import mlir
from jax._src.lib import cuda_versions
from jax.test_util import check_grads
from jax import nn
from jax import random

@@ -36,8 +38,106 @@

config.parse_flags_with_absl()

def _is_required_cudnn_version_satisfied():
Reviewer: How about something like …? Note also that you can remove ….
Author: Done.
  return (
      jtu.is_cuda_compute_capability_at_least("8.0") and
      cuda_versions is not None and
      cuda_versions.cudnn_get_version() >= 8904
  )

def _get_causal_mask(T, S):
  causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
  return causal_mask[jnp.newaxis, jnp.newaxis, :, :]

@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
      use_bias=(False, True),
      causal_mode=(None, 'is_causal', 'is_mask'),
      impl=('xla', 'cudnn'),
  )
  def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, impl):
    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
    if impl == 'cudnn' and dtype == jnp.float32:
      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")

    sdpa = nn.dot_product_attention
    B, S, T, N, H = 2, 128, 128, 4, 32
    keys = random.split(random.PRNGKey(0), 4)
    Q = random.normal(keys[0], (B, T, N, H), dtype)
    K = random.normal(keys[1], (B, S, N, H), dtype)
    V = random.normal(keys[2], (B, S, N, H), dtype)
    if use_bias:
      bias = random.normal(keys[3], (1, N, T, S), dtype)
    else:
      bias = None

    is_causal = causal_mode == 'is_causal'
    causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None

    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)

    if impl == 'cudnn':
      lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask)
      hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
      self.assertIn('__cudnn$fmha', hlo)

    out_ref = sdpa_ref(Q, K, V, bias=bias, mask=causal_mask)
    out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask)
    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)

  @parameterized.product(
      dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
      use_bias=[False, True],
      causal_mode=[None, 'is_causal', 'is_mask'],
      impl=['xla', 'cudnn'],
  )
  def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, impl):
    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
    if impl == 'cudnn' and dtype == jnp.float32:
      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")

    sdpa = nn.dot_product_attention
    B, S, T, N, H = 2, 128, 128, 4, 32
    keys = random.split(random.PRNGKey(0), 5)
    Q = random.normal(keys[0], (B, T, N, H), dtype)
    K = random.normal(keys[1], (B, S, N, H), dtype)
    V = random.normal(keys[2], (B, S, N, H), dtype)
    grad = random.normal(keys[3], (B, T, N, H), dtype)
    if use_bias:
      bias = random.normal(keys[4], (1, N, T, S), dtype)
    else:
      bias = None

    is_causal = causal_mode == 'is_causal'
    causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None
Reviewer: A very important test is to verify that if the user states cudnn as the implementation, it is actually used. I think we should check the HLO, which you can get like: …
Author: Done.
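A sketch of such an HLO check, mirroring what the tests above now do (with `sdpa_ans`, `Q`, `K`, `V`, `bias`, and `causal_mask` as defined in the test):

# Lower the jitted function and inspect its StableHLO to confirm that the
# cuDNN fused-attention custom call is actually emitted.
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn('__cudnn$fmha', hlo)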
    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
    fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m)
    _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K, V, bias, causal_mask)
    dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)

    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
    fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m)
    _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask)
    dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad)

    if impl == 'cudnn':
      lowered = jax.jit(sdpa_vjp_ans).lower(grad)
      hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
      self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')

    rtol, atol = (.01, .01)
    self.assertAllClose(dQ_ref, dQ_ans, rtol=rtol, atol=atol)
    self.assertAllClose(dK_ref, dK_ans, rtol=rtol, atol=atol)
    self.assertAllClose(dV_ref, dV_ans, rtol=rtol, atol=atol)
    self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSoftplusGrad(self):
    check_grads(nn.softplus, (1e-8,), order=4,
Reviewer: Can you add a comment explaining the magic number here? Why is -0.7 used?
Author: Hmm, actually, I am not sure about the rationale for using -0.7, but it seems to be a rule of thumb used in various places: Praxis, the JAX paged-attention kernel, the JAX splash-attention kernel, etc.
Reviewer: I cannot find an explanation in those places either.
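One plausible rationale, offered here only as an illustration rather than as the explanation behind those kernels: using the full `finfo(dtype).max` as the mask value leaves no headroom, so adding an ordinary bias or logit overflows to `-inf` in fp16, and fully masked rows can then turn into NaNs inside the softmax; scaling by 0.7 keeps such sums finite while `exp()` still underflows to zero.

import jax.numpy as jnp

dtype = jnp.float16
dtype_max = jnp.finfo(dtype).max                     # 65504 for float16

full_neg = jnp.asarray(-dtype_max, dtype)            # no headroom at all
headroom_neg = jnp.asarray(-0.7 * dtype_max, dtype)  # the rule-of-thumb value

bias = jnp.asarray(-100.0, dtype)                    # an ordinary additive bias
print(full_neg + bias)        # -inf: the sum overflows the fp16 range
print(headroom_neg + bias)    # still finite, thanks to the ~30% headroom

# Either way, a masked logit still vanishes after exponentiation:
print(jnp.exp(headroom_neg))  # 0.0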