
Commit

revert the attention commit
Chris Rinard committed Oct 3, 2023
1 parent 568e61d commit 803e671
Showing 16 changed files with 373 additions and 557 deletions.
201 changes: 40 additions & 161 deletions llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@

import math
import warnings
-from typing import Any, List, Optional, Tuple
+from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -18,7 +18,7 @@


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
-                     original_is_causal: bool) -> bool:
+                     original_is_causal: bool):
# disable causal when it is not needed
# necessary for flash & triton for generation with kv_cache
if original_is_causal and num_query_tokens != num_key_tokens:
@@ -31,23 +31,6 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
return original_is_causal
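The early return above exists for generation with a kv_cache: each decoding step scores one new query token against all cached keys, so a square causal mask no longer applies and the flash/triton kernels can skip masking. A quick illustration with hypothetical token counts:

# prefill: square attention, the causal mask stays on
_reset_is_causal(num_query_tokens=128, num_key_tokens=128,
                 original_is_causal=True)  # -> True

# kv-cached decode step: one query vs. 129 keys, causality is vacuous
_reset_is_causal(num_query_tokens=1, num_key_tokens=129,
                 original_is_causal=True)  # -> False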


-def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """Perform repeat of kv heads along a particular dimension.
-
-    hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
-    n_rep: amount of repetitions of kv_n_heads
-    Unlike torch.repeat_interleave, this function avoids allocating new memory.
-    """
-    if n_rep == 1:
-        return hidden
-
-    b, s, kv_n_heads, d = hidden.shape
-
-    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
-
-    return hidden.reshape(b, s, kv_n_heads * n_rep, d)
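For reference, the deleted helper above and the repeat_interleave calls that replace it in the hunks below produce identical tensors; the expand step itself is allocation-free, although the final reshape of the non-contiguous view still copies. A minimal equivalence check, reusing the deleted body:

import torch

def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # the deleted helper, reproduced for comparison
    if n_rep == 1:
        return hidden
    b, s, kv_n_heads, d = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)

x = torch.randn(2, 5, 4, 8)  # (batch, seq, kv_n_heads, head_dim)
assert torch.equal(repeat_kv_for_gqa(x, 3), x.repeat_interleave(3, dim=2))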


def scaled_multihead_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -101,11 +84,8 @@ def scaled_multihead_dot_product_attention(

# grouped query case
if kv_n_heads > 1 and kv_n_heads < n_heads:
-        # necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function
-        k = repeat_kv_for_gqa(k.transpose(1, 2),
-                              n_heads // kv_n_heads).transpose(1, 2)
-        v = repeat_kv_for_gqa(v.transpose(1, 2),
-                              n_heads // kv_n_heads).transpose(1, 2)
+        k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
+        v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)

if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
@@ -263,16 +243,10 @@ def flash_attn_fn(
elif kv_n_heads < n_heads:
        # Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
        # We repeat each kv head by the group size number to use the underlying MHA kernels
-
-        # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
-        # we use .view to modify {key, value}_unpad appropriately
-
-        key_unpad = repeat_kv_for_gqa(
-            key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
-            n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
-        value_unpad = repeat_kv_for_gqa(
-            value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
-            n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
+        # done along the head dimension = 1
+        key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
+        value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
+                                                    dim=1)

dropout_p = dropout_p if training else 0.0

@@ -399,108 +373,6 @@ def triton_flash_attn_fn(
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)

# multi-query case
if kv_n_heads == 1:
# necessary to repeat instead of expand tensor because
# output contains NaN in edge cases such as with head dimension = 8
key = key.repeat(1, 1, n_heads, 1)
value = value.repeat(1, 1, n_heads, 1)
# grouped query case
elif kv_n_heads < n_heads:
        # Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
# We repeat each kv head by the group size number to use the underlying MHA kernels
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
attn_output = flash_attn_func( # type: ignore
query, key, value, attn_bias, reset_is_causal, softmax_scale)

output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore

return output, None, past_key_value
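The repeat-instead-of-expand comment above comes down to memory layout: .expand returns a stride-0 view that some fused kernels mishandle (hence the NaN edge cases), while .repeat materializes a contiguous copy. A small sketch with illustrative shapes:

import torch

key = torch.randn(2, 16, 1, 64)      # (b, s, kv_n_heads=1, head_dim)

expanded = key.expand(2, 16, 8, 64)  # view only; stride 0 on the head dim
repeated = key.repeat(1, 1, 8, 1)    # real, contiguous copy

print(expanded.is_contiguous())      # False
print(repeated.is_contiguous())      # True
print(expanded.stride())             # (1024, 64, 0, 1)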

def xformers_attn_fn(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
torch.Tensor]]]:

try:
from xformers.ops import memory_efficient_attention
except:
raise RuntimeError(
'Please install xformers.'
)

check_valid_inputs(query, key, value)

if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
))
kv_n_heads = n_heads

if past_key_value is not None:
if len(past_key_value) != 0:
key = torch.cat([past_key_value[0], key], dim=1)
value = torch.cat([past_key_value[1], value], dim=1)

past_key_value = (key, value)

if attn_bias is not None:
# clamp to 0 necessary for torch 2.0 compile()
_s_q = max(0, attn_bias.size(2) - query.size(1))
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        # broadcast the bias across the batch dimension; xformers expects a
        # materialized (b, h, s_q, s_k) bias
        attn_bias = attn_bias.expand(query.size(0), -1, -1, -1)
    # memory_efficient_attention applies dropout itself via its `p` argument
    dropout_p = dropout_p if training else 0.0

    if needs_weights:
        raise NotImplementedError(
            'attn_impl: xformers cannot return attn weights.')

if key_padding_mask is not None:
warnings.warn(
'Propagating key_padding_mask to the attention module ' +\
'and applying it within the attention module can cause ' +\
'unnecessary computation/memory usage. Consider integrating ' +\
'into attn_bias once and passing that to each attention ' +\
'module instead.'
)
b_size, s_k = key_padding_mask.shape[:2]

if attn_bias is None:
attn_bias = query.new_zeros(b_size, 1, 1, s_k)

attn_bias = attn_bias.masked_fill(
~key_padding_mask.view((b_size, 1, 1, s_k)),
torch.finfo(query.dtype).min)

query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
# multi-query case
if kv_n_heads == 1:
# necessary to repeat instead of expand tensor because
@@ -516,13 +388,14 @@ def xformers_attn_fn(query: torch.Tensor,
value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output = memory_efficient_attention(  # type: ignore
-        query, key, value, attn_bias, p=dropout_p)
+    attn_output = flash_attn_func(  # type: ignore
+        query, key, value, attn_bias, reset_is_causal, softmax_scale)

output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore

return output, None, past_key_value


class GroupedQueryAttention(nn.Module):
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
@@ -545,8 +418,8 @@ def __init__(
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
+        verbose: int = 0,
        device: Optional[str] = None,
-        bias: bool = True,
):
super().__init__()

@@ -578,9 +451,7 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

-        fc_kwargs: dict[str, Any] = {
-            'bias': bias,
-        }
+        fc_kwargs = {}
if fc_type != 'te':
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
@@ -593,7 +464,7 @@ def __init__(
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
-        self.Wqkv._fused = (0, fuse_splits)
+        self.Wqkv._fused = (0, fuse_splits)  # type: ignore

if self.qk_ln:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
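The fuse_splits boundaries in the previous hunk mark where the fused Wqkv output separates into query, key, and value heads. With hypothetical sizes d_model=512, n_heads=8, kv_n_heads=2 (so head_dim=64), the fused projection is (8 + 2 + 2) * 64 = 768 features wide and the boundaries land every 64 features:

d_model, n_heads, kv_n_heads = 512, 8, 2  # hypothetical sizes
head_dim = d_model // n_heads             # 64

fuse_splits = [i * head_dim for i in range(1, n_heads + 2 * kv_n_heads)]
print(fuse_splits)
# [64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704]
# 11 boundaries -> 12 head-sized chunks: 8 query, 2 key, 2 value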
@@ -605,8 +476,21 @@ def __init__(
self.attn_fn = flash_attn_fn
elif self.attn_impl == 'triton':
self.attn_fn = triton_flash_attn_fn
+            if verbose:
+                warnings.warn(
+                    'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
+                    'it uses more memory. When training larger models this can trigger ' +\
+                    'alloc retries which hurts performance. If encountered, we recommend ' +\
+                    'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
+                )
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
+            if torch.cuda.is_available() and verbose:
+                warnings.warn(
+                    'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
+                    '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
+                    'we recommend using `attn_impl: triton`.'
+                )
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

@@ -615,7 +499,7 @@ def __init__(
self.d_model,
**fc_kwargs,
)
-        self.out_proj._is_residual = True
+        self.out_proj._is_residual = True  # type: ignore

def forward(
self,
@@ -625,8 +509,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = True,
needs_weights: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
-            torch.Tensor, torch.Tensor]]]:
+    ):
qkv = self.Wqkv(x)

if self.clip_qkv:
@@ -686,8 +569,8 @@ def __init__(
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
+        verbose: int = 0,
        device: Optional[str] = None,
-        bias: bool = True,
):
super().__init__(
d_model=d_model,
@@ -700,9 +583,8 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
-            device=device,
-            bias=bias,
-        )
+            verbose=verbose,
+            device=device)


class MultiQueryAttention(GroupedQueryAttention):
@@ -723,8 +605,8 @@ def __init__(
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
+        verbose: int = 0,
        device: Optional[str] = None,
-        bias: bool = True,
):
super().__init__(
d_model=d_model,
@@ -737,15 +619,12 @@ def __init__(
attn_pdrop=attn_pdrop,
norm_type=norm_type,
fc_type=fc_type,
-            device=device,
-            bias=bias,
-        )
+            verbose=verbose,
+            device=device)


-def attn_bias_shape(
-        attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
-        prefix_lm: bool, causal: bool,
-        use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
+def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
+                    prefix_lm: bool, causal: bool, use_sequence_id: bool):
if attn_impl == 'flash':
return None
elif attn_impl in ['torch', 'triton']:
@@ -768,7 +647,7 @@ def build_attn_bias(
causal: bool = False,
alibi: bool = False,
alibi_bias_max: int = 8,
-) -> Optional[torch.Tensor]:
+):
if attn_impl == 'flash':
return None
elif attn_impl in ['torch', 'triton']:
@@ -791,7 +670,7 @@ def build_attn_bias(

def gen_slopes(n_heads: int,
alibi_bias_max: int = 8,
device: Optional[torch.device] = None) -> torch.Tensor:
device: Optional[torch.device] = None):
_n_heads = 2**math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
m = m.mul(alibi_bias_max / _n_heads)
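Only the opening lines of gen_slopes are visible above; the collapsed remainder presumably completes the standard ALiBi recipe, in which each head receives a geometrically decaying slope. A worked sketch for n_heads=8 under that assumption:

import math

import torch

n_heads, alibi_bias_max = 8, 8
_n_heads = 2**math.ceil(math.log2(n_heads))  # 8, already a power of two
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)         # 1, 2, ..., 8
slopes = 1.0 / torch.pow(2, m)               # 1/2, 1/4, ..., 1/256
# each head h then adds slopes[h] * -(key distance) to its attention logits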
@@ -813,7 +692,7 @@ def build_alibi_bias(
alibi_bias_max: int = 8,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
-) -> torch.Tensor:
+):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32,
device=device).view(1, 1, 1, seq_len)
if full:
