[NVIDIA] Add new SDPA API to jax.nn #21371

Closed
wants to merge 1 commit into from
161 changes: 160 additions & 1 deletion jax/_src/nn/functions.py
@@ -16,10 +16,11 @@

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import operator
import numpy as np
from typing import Any
from typing import Any, Literal
import warnings

import jax
@@ -31,6 +32,8 @@
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike
from jax._src.ops.special import logsumexp as _logsumexp
@@ -765,3 +768,159 @@ def hard_silu(x: ArrayLike) -> Array:
return x_arr * hard_sigmoid(x_arr)

hard_swish = hard_silu

def _get_large_negative(dtype):
dtype_max = jnp.finfo(dtype).max
return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
Collaborator:
Can you add a comment explaining the magic number here? Why is -0.7 used?

Contributor Author:
Hmm, actually, I am not sure about the rationale for using -0.7, but it seems to be a rule of thumb used in various places: Praxis, jax paged attention, jax splash attention, etc.

Contributor Author:
I cannot find an explanation in those places, though.
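
For illustration, a minimal sketch (not from the PR; the values are made up) of the usual intuition behind a large-but-finite mask value: -0.7 * finfo.max behaves like -inf under softmax, while staying finite so that adding ordinary-sized logits or a bias does not overflow the dtype.

import jax
import jax.numpy as jnp

dtype = jnp.float16
big_neg = jnp.asarray(-0.7 * jnp.finfo(dtype).max, dtype)    # ~ -45856, still finite
logits = jnp.array([1.0, 2.0, 0.0], dtype)
masked = logits.at[2].set(big_neg)                           # mask out the last position
probs = jax.nn.softmax(masked.astype(jnp.float32))           # softmax in fp32, as in the PR
print(probs)                                                 # last entry is ~0.0, no NaNs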


def _get_causal_mask(T, S, dtype):
pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
return mask[jnp.newaxis, jnp.newaxis, :, :]

def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)

logits *= jnp.array(scale, dtype=logits.dtype)

if bias is not None:
logits = (logits + bias).astype(logits.dtype)

if mask is not None:
assert mask.dtype == jnp.bool_
large_negative_number = _get_large_negative(logits.dtype)
padded_logits = jnp.where(mask, logits, large_negative_number)
else:
padded_logits = logits

if is_causal:
T, S = query.shape[-3], key.shape[-3]
mask = _get_causal_mask(T, S, logits.dtype)
padded_logits = padded_logits + mask

# Softmax is always carried out in fp32.
padded_logits = padded_logits.astype(jnp.float32)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)

encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
return encoded

def dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
*,
bias: ArrayLike | None = None,
Contributor:
Let's make bias and mask mandatory keyword args.

Contributor Author:
I only made mask a mandatory keyword arg. If bias is a mandatory keyword arg, vjp(sdpa, q, k, v, bias=bias) complains, and when I tried moving bias=bias into partial(), that affected how I get the dbias.

Contributor:
What complaint do you get? You can always use a lambda in vjp tests to turn it into positional args for testing.

Contributor Author:
Yes, thanks for the advice. Done.
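
For reference, a minimal sketch of the lambda workaround with a toy function (f, x, and b are illustrative names, not the PR's sdpa): jax.vjp takes primals positionally, so a keyword-only argument is forwarded through a lambda to obtain its cotangent as well.

import jax
import jax.numpy as jnp

def f(x, *, bias=None):
  y = x if bias is None else x + bias
  return jnp.sum(y ** 2)

x = jnp.ones(3)
b = jnp.arange(3.0)
wrapped = lambda x, b: f(x, bias=b)   # expose the keyword arg positionally
out, vjp_fn = jax.vjp(wrapped, x, b)
dx, db = vjp_fn(1.0)                  # cotangents for both x and bias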

mask: ArrayLike | None = None,
scale: float | None = None,
is_causal: bool = False,
implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
Collaborator:
Do you think we can go with a shorter parameter name here, e.g. impl or backend?

Contributor:
We use 'backend' in JAX for GPU/CPU/TPU. I do think implementation is reasonable. And shouldn't abbreviations be avoided?

Contributor Author:
Oh, just noticed the new comment. Changed it back to implementation.

r"""Scaled dot product attention function.

Computes the attention function on Query, Key, and Value tensors:

.. math ::
\mathrm{Attention}(Q, K, V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

We refer to the output of :math:`QK^T` as :code:`logits` and the output of the
:math:`\mathrm{softmax}` as :code:`probs`.

Throughout this function, we use the following uppercase letters to denote the
array shapes:

B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimensions of each attention head

Args:
query: query array; shape :code:`(BTNH)`
key: key array; shape :code:`(BSNH)`
value: value array; shape :code:`(BSNH)`
bias: optional, bias array to be added to logits; shape broadcastable to
:code:`(BNTS)`.
mask: optional, mask array used to filter out logits. It is a boolean mask
where `True` indicates the element should take part in attention. For an
additive mask, users should pass it to `bias`. The shape is broadcastable
to :code:`(BNTS)`.
scale: scale for the logits. If None, the scale will be set to 1 divided by
the square root of query's head dimension (i.e. H).
is_causal: If true, causal attention will be applied. Note, some
implementations like `xla` will generate a mask tensor and apply it to the
logits, but other implementations like `cudnn` will avoid computing the
unmasked regions.
implementation: A string to control which implementation backend to use.
Supported strings are `xla` and `cudnn` (cuDNN flash attention). It defaults
to `None`, which will automatically select the best available backend.
Note that `cudnn` supports only a subset of shapes/dtypes, and an exception
will be thrown if it is not supported.

Returns:
An array of the attention output with the same shape as :code:`query`.
"""
def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
if t.ndim != len(shape):
raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
for i in range(t.ndim):
if shape[i] != -1 and t.shape[i] != shape[i]:
raise ValueError(f"{name} shape should be {shape}, but got {t.shape}")

query = jnp.asarray(query)
key = jnp.asarray(key)
value = jnp.asarray(value)
bias = bias if bias is None else jnp.asarray(bias)
mask = mask if mask is None else jnp.asarray(mask)

B, S, N, H = key.shape
Collaborator:
You need to call asarray() on key and the other arguments before accessing .shape, because they are annotated to have type ArrayLike, which e.g. includes Python scalars.

Contributor Author:
Done.

_check_has_shape(value, [B, S, N, H], 'value')
_check_has_shape(query, [B, -1, N, H], 'query')
T = query.shape[1]
scale_val = (1.0 / np.sqrt(H)) if scale is None else scale
if not (query.dtype == key.dtype == value.dtype):
raise ValueError(f"query/key/value should have the same dtype, but got "
f"{query.dtype} vs {key.dtype} vs {value.dtype}.")
if mask is not None and mask.dtype != jnp.bool_:
raise ValueError(f"Mask must be boolean dtype, but got {mask.dtype}.")

match implementation:
case 'xla':
return _dot_product_attention_xla(
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
)
case 'cudnn':
mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
# Convert bool mask to float mask for addition
if mask is not None:
large_negative_number = _get_large_negative(query.dtype)
mask = jnp.where(mask, jnp.zeros((), query.dtype),
large_negative_number)

# Prepare the bias for cudnn flash attention:
# We should never use the mask argument of cudnn, because it is
# multiplicative and thus the masked values (i.e. the zeros) will
# still take part in the following softmax. So, we need to use the bias
# argument for the mask to ensure the masked values are very small.
# TODO(kaixih@nvidia): The logic should be moved to the internal of
# cudnn_dot_product_attention.
if bias is None:
bias = mask
elif mask is not None:
bias = bias + mask

return cudnn_dot_product_attention(
query, key, value, bias, mask=None, scale=scale_val,
mask_type=mask_type,
)
case None:
# TODO(kaixih@nvidia): Defaults to XLA for now. Will automatically select the
# best backend.
return _dot_product_attention_xla(
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
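
To illustrate the API documented above, a short usage sketch (shapes follow the B/T/S/N/H convention from the docstring; the commented-out cuDNN call assumes a GPU with a supported cuDNN build):

import jax
import jax.numpy as jnp
from jax import random

B, T, S, N, H = 2, 128, 128, 4, 32
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
q = random.normal(k1, (B, T, N, H), jnp.bfloat16)
k = random.normal(k2, (B, S, N, H), jnp.bfloat16)
v = random.normal(k3, (B, S, N, H), jnp.bfloat16)

# Boolean mask broadcastable to (B, N, T, S); True means "attend".
mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))[None, None, :, :]

out = jax.nn.dot_product_attention(q, k, v, mask=mask)  # default (XLA) path
# out = jax.nn.dot_product_attention(q, k, v, is_causal=True,
#                                    implementation='cudnn')  # cuDNN flash attention
print(out.shape)  # (2, 128, 4, 32)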
1 change: 1 addition & 0 deletions jax/nn/__init__.py
@@ -36,6 +36,7 @@
one_hot as one_hot,
relu as relu,
relu6 as relu6,
dot_product_attention as dot_product_attention,
selu as selu,
sigmoid as sigmoid,
soft_sign as soft_sign,
100 changes: 100 additions & 0 deletions tests/nn_test.py
@@ -28,6 +28,8 @@
from jax._src import core
from jax._src import test_util as jtu
from jax._src import ad_checkpoint
from jax._src.interpreters import mlir
from jax._src.lib import cuda_versions
from jax.test_util import check_grads
from jax import nn
from jax import random
@@ -36,8 +38,106 @@

config.parse_flags_with_absl()

def _is_required_cudnn_version_satisfied():
Collaborator:
How about something like

return (
    jtu.test_device_matches(["cuda"]) and
    jtu.is_cuda_compute_capability_at_least("8.0") and
    cuda_version is not None and
    cuda_versions.cudnn_get_version() < 8904
)

?

Note also that you can remove jtu.test_device_matches(["cuda"]), because jtu.is_cuda_compute_capability_at_least already checks that.

Contributor Author:
Done.

return (
jtu.is_cuda_compute_capability_at_least("8.0") and
cuda_versions is not None and
cuda_versions.cudnn_get_version() >= 8904
)

def _get_causal_mask(T, S):
causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
return causal_mask[jnp.newaxis, jnp.newaxis, :, :]

@jtu.with_config(jax_legacy_prng_key="allow",
jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=(False, True),
causal_mode=(None, 'is_causal', 'is_mask'),
impl=('xla', 'cudnn'),
)
def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")

sdpa = nn.dot_product_attention
B, S, T, N, H = 2, 128, 128, 4, 32
keys = random.split(random.PRNGKey(0), 4)
Q = random.normal(keys[0], (B, T, N, H), dtype)
K = random.normal(keys[1], (B, S, N, H), dtype)
V = random.normal(keys[2], (B, S, N, H), dtype)
if use_bias:
bias = random.normal(keys[3], (1, N, T, S), dtype)
else:
bias = None

is_causal = causal_mode == 'is_causal'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None

sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)

if impl == 'cudnn':
lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn('__cudnn$fmha', hlo)

out_ref = sdpa_ref(Q, K, V, bias=bias, mask=causal_mask)
out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask)
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)

@parameterized.product(
dtype=[jnp.float32, jnp.bfloat16, jnp.float16],
use_bias=[False, True],
causal_mode=[None, 'is_causal', 'is_mask'],
impl=['xla', 'cudnn'],
)
def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied():
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")

sdpa = nn.dot_product_attention
B, S, T, N, H = 2, 128, 128, 4, 32
keys = random.split(random.PRNGKey(0), 5)
Q = random.normal(keys[0], (B, T, N, H), dtype)
K = random.normal(keys[1], (B, S, N, H), dtype)
V = random.normal(keys[2], (B, S, N, H), dtype)
grad = random.normal(keys[3], (B, T, N, H), dtype)
if use_bias:
bias = random.normal(keys[4], (1, N, T, S), dtype)
else:
bias = None

is_causal = causal_mode == 'is_causal'
causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None

Contributor:
A very important test is to verify that if the user specifies cudnn as the implementation, it is actually used. I think we should check the HLO, which you can get like:

jax.jit(f).lower(args).compiler_ir('stablehlo')

Contributor Author:
Done.

sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m)
_, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K, V, bias, causal_mask)
dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad)

sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl)
fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m)
_, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask)
dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad)

if impl == 'cudnn':
lowered = jax.jit(sdpa_vjp_ans).lower(grad)
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(')

rtol, atol = (.01, .01)
self.assertAllClose(dQ_ref, dQ_ans, rtol=rtol, atol=atol)
self.assertAllClose(dK_ref, dK_ans, rtol=rtol, atol=atol)
self.assertAllClose(dV_ref, dV_ans, rtol=rtol, atol=atol)
self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
check_grads(nn.softplus, (1e-8,), order=4,