From 61beeb9cc90f32e3930dadd42bfccc49531e013f Mon Sep 17 00:00:00 2001
From: kaixih
Date: Wed, 13 Mar 2024 17:09:42 +0000
Subject: [PATCH] Add cudnn attention layer

---
 praxis/layers/gpu_fast_attention.py | 87 +++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/praxis/layers/gpu_fast_attention.py b/praxis/layers/gpu_fast_attention.py
index d46a8aed..8da1c014 100644
--- a/praxis/layers/gpu_fast_attention.py
+++ b/praxis/layers/gpu_fast_attention.py
@@ -25,6 +25,10 @@
 import jax
 from jax.experimental.shard_map import shard_map
+from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
+from jax import numpy as jnp
+import numpy as np
+
 from praxis import asserts
 from praxis import base_layer
 from praxis import py_utils
@@ -46,6 +50,89 @@ JTensor = pytypes.JTensor
+class GpuCudnnFusedDotProductAttention(attentions.DotProductAttention):
+  """Using Jax/Cudnn to call into a fused MHA kernel on NVIDIA GPU."""
+  is_causal: bool = False
+
+  def _shard_only_bn(self, x: JTensor) -> JTensor:
+    """Adds sharding annotations to tensors of shape [b, n, None, None]."""
+    ap = self.activation_split_dims_mapping
+    if self.mesh_axis_names is None or ap.blnh is None:
+      return x
+    assert len(ap.blnh) == 4
+    b = [ap.blnh[0], ap.blnh[2], None, None]
+    return base_layer.maybe_shard(x, b, self.mesh_axis_names)
+
+  def _dot_atten(
+      self,
+      query: JTensor,
+      key: JTensor,
+      value: JTensor,
+      atten_mask: JTensor,
+      relative_bias: JTensor | None = None,
+  ) -> tuple[JTensor, JTensor]:
+    """Main attention function.
+
+    Args:
+      query: JTensor of shape [B, T, N, H].
+      key: JTensor of shape [B, S, N, H].
+      value: JTensor of shape [B, S, N, H].
+      atten_mask: JTensor of shape [1|B, 1, 1|T, S] which is a mask that is
+        applied to prevent attention between unwanted pairs. This has already
+        been converted into large negative logits. Note that the first and
+        third dimension allow size 1 if the mask is shared by every item in
+        the batch or every token in the target sequence.
+      relative_bias: Relative bias of shape [B, N, T, S].
+
+    Returns:
+      encoded: JTensor of shape [B, T, N, H].
+      atten_probs: JTensor of shape [B, N, T, S].
+    """
+    query = self._shard_blnh(query)
+    key = self._shard_blnh(key)
+    value = self._shard_blnh(value)
+
+    b, s, n, h = key.shape
+    base_layer.assert_has_shape(value, [b, s, n, h])
+    base_layer.assert_has_shape(query, [b, -1, n, h])
+    t = query.shape[1]
+    # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S]
+    # since each target token is prohibited from attending to the same set of
+    # source tokens. In this case tiling is inefficient and unnecessary.
+    # If there is no padding mask, and only causal mask then the shape can be
+    # [1, 1, T, S]
+    base_layer.assert_has_shape(atten_mask, [-1, 1, -1, s])
+    asserts.in_set(atten_mask.shape[2], [t, 1])
+    asserts.in_set(atten_mask.shape[0], [b, 1])
+
+    assert self.attention_extra_logit is None
+    assert not self.zero_fully_masked
+    assert not self.atten_logit_cap or self.atten_logit_cap <= 0.0
+
+    query = self._scale_query(query)
+    logits_scale = 1.0 / np.sqrt(h) if self.scale_logits_by_head_dims else 1.0
+
+    # Explicitly shard the relative_bias to ensure it has the same sharding on
+    # batch and num_head dim with the query. This is required by the
+    # dot_product_attention.
+    if relative_bias is not None:
+      relative_bias = self._shard_only_bn(relative_bias)
+
+    # We manually transpose the inputs to BNTH due to a convergence issue when
+    # directly passing BTNH inputs.
+    # TODO(kaixih@nvidia) Remove the transpose when the issue is fixed.
+    query = jnp.einsum('BTNH->BNTH', query)
+    key = jnp.einsum('BSNH->BNSH', key)
+    value = jnp.einsum('BSNH->BNSH', value)
+    encoded = dot_product_attention(
+        query, key, value, relative_bias, scale=logits_scale,
+        is_causal_mask=self.is_causal, dropout_rate=self.atten_dropout_prob,
+        qkv_layout='BNTH', is_training=not self.do_eval,
+    )
+    encoded = jnp.einsum('BNTH->BTNH', encoded)
+    encoded = self._shard_blnh(encoded)
+    return encoded, None
+
+
 class GpuTritonFusedDotProductAttention(attentions.DotProductAttention):
   """Using Jax/Pallas/Triton to call into a fused MHA kernel on NVIDIA GPU."""
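
Standalone sketch (not part of the patch): a minimal call to the same
jax._src.cudnn.fused_attention_stablehlo.dot_product_attention entry point
that GpuCudnnFusedDotProductAttention._dot_atten wraps, using the BNTH layout
and the keyword arguments shown in the diff. The tensor sizes, the bfloat16
dtype, and the no-bias / no-dropout settings are illustrative assumptions; it
only runs on an NVIDIA GPU with a cuDNN version that JAX's fused-attention
lowering supports.

    import jax
    import jax.numpy as jnp
    from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

    # Illustrative sizes (assumed): batch=2, heads=8, seq_len=128, head_dim=64.
    b, n, t, h = 2, 8, 128, 64
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
    query = jax.random.normal(kq, (b, n, t, h), dtype=jnp.bfloat16)
    key = jax.random.normal(kk, (b, n, t, h), dtype=jnp.bfloat16)
    value = jax.random.normal(kv, (b, n, t, h), dtype=jnp.bfloat16)

    # Same call pattern as the layer above: bias=None stands in for the
    # optional relative_bias, and causal masking is delegated to the kernel.
    encoded = dot_product_attention(
        query, key, value, None,
        scale=1.0 / (h ** 0.5),
        is_causal_mask=True,
        dropout_rate=0.0,
        qkv_layout='BNTH',
        is_training=False,
    )
    # The kernel returns BNTH; transpose back to BTNH as the layer does.
    encoded = jnp.einsum('BNTH->BTNH', encoded)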
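
A second, equally hypothetical sketch: selecting the new layer in a Praxis
model by swapping the attention template of a transformer block. The
transformer field names used here (input_dims, hidden_dims, num_heads,
tr_atten_tpl) follow the stock praxis.layers.transformers.Transformer and are
assumptions rather than something this patch touches; only is_causal and
atten_dropout_prob reach the new layer.

    from praxis import pax_fiddle
    from praxis.layers import gpu_fast_attention
    from praxis.layers import transformers

    # Hypothetical transformer config; the sizes are placeholders.
    xformer_p = pax_fiddle.Config(
        transformers.Transformer,
        name='xformer',
        input_dims=1024,
        hidden_dims=4096,
        num_heads=8,
    )
    # Point the attention template at the cuDNN-fused layer added above.
    xformer_p.tr_atten_tpl = pax_fiddle.Config(
        gpu_fast_attention.GpuCudnnFusedDotProductAttention,
        is_causal=True,          # new field introduced by this patch
        atten_dropout_prob=0.0,  # forwarded to the kernel's dropout_rate
    )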