Monkeypatch flash attention in for llama #520

Merged (16 commits) on Aug 15, 2023
17 changes: 17 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -21,6 +21,8 @@

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.utils import init_empty_weights

try:
@@ -178,6 +180,21 @@ def __init__(
f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
)

attention_patch_type = om_model_config.get('attention_patch_type', None)
if attention_patch_type is not None:
if model.config.model_type != 'llama':
raise ValueError(
f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
)

print(
f'Patching llama attention with {attention_patch_type} attention'
)
from transformers.models.llama.modeling_llama import LlamaAttention
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False

composer_model = super().__init__(model=model,
shift_labels=True,
tokenizer=tokenizer,
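For context on how the new `attention_patch_type` flag is consumed, here is a minimal sketch of a model sub-config that would exercise this code path. Only `attention_patch_type` (with values 'torch' or 'triton') comes from this PR; the other keys and the checkpoint name are illustrative assumptions about a typical llm-foundry hf_causal_lm config.

# Illustrative only: a model config that would trigger the llama attention patch.
from omegaconf import OmegaConf

model_cfg = OmegaConf.create({
    'name': 'hf_causal_lm',                                    # assumed builder name
    'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',  # assumed checkpoint
    'pretrained': True,
    'attention_patch_type': 'triton',  # or 'torch'; patches LlamaAttention.forward and disables use_cache
})
print(model_cfg.get('attention_patch_type', None))  # -> 'triton'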
283 changes: 283 additions & 0 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -0,0 +1,283 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# This file is copied and modified from
# https://github.com/huggingface/transformers/blob/fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db/src/transformers/models/llama/modeling_llama.py
# See the clearly denoted code blocks for the main modifications (there are a few other minor changes, such as type ignores and error messages).

import logging
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F  # F.linear lives in torch.nn.functional

from llmfoundry.models.layers.attention import (
scaled_multihead_dot_product_attention, triton_flash_attn_fn)

log = logging.getLogger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).

The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim).
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :,
None, :, :].expand(batch, num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
head_dim)
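A quick shape check for repeat_kv (editorial note, not part of the patch; run in a scope where torch and this module's definitions are imported). Note that the patched forwards below do not call repeat_kv themselves; they pass kv_n_heads to the foundry attention functions, which are expected to handle the grouped-query head expansion.

x = torch.randn(2, 4, 8, 16)                     # (batch, num_key_value_heads, seqlen, head_dim)
assert repeat_kv(x, 1) is x                      # n_rep == 1 is a no-op
assert repeat_kv(x, 3).shape == (2, 12, 8, 16)   # KV heads repeated to match attention heads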


def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, position_ids: torch.Tensor):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
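And a shape check for the rotary-embedding helpers (again illustrative; the cos/sin values here are random, whereas in the patched forward they come from self.rotary_emb, which in transformers 4.31 returns tensors of shape (1, 1, seq_len, head_dim)).

bsz, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
cos = torch.randn(1, 1, seq_len, head_dim)
sin = torch.randn(1, 1, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape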


def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:
if patch_fn_name == 'torch':
return llama_attention_patch_torch
elif patch_fn_name == 'triton':
return llama_attention_patch_triton
else:
raise ValueError(
f'Unrecognized llama attention patch function: {patch_fn_name}')
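For reference, applying the patch outside of ComposerHFCausalLM amounts to the same one-line monkeypatch that the hf_causal_lm.py change performs automatically (a sketch; remember the patched forward rejects use_cache=True, so the KV cache must stay disabled).

from transformers.models.llama.modeling_llama import LlamaAttention

LlamaAttention.forward = get_llama_attention_patch_fn('triton')
# Any Llama model built after this point uses the patched attention;
# set model.config.use_cache = False, as the hf_causal_lm.py change does.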


def llama_attention_patch_torch(
self, # type: ignore
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
'use_cache is not yet supported when patching Llama attention.')

bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads *
self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp,
dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)

key_states = [
F.linear(hidden_states, key_slices[i]) # type: ignore (thirdParty)
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)

value_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin,
position_ids) # type: ignore (thirdParty)

### MAIN MODIFICATIONS START HERE ###
query_states = query_states.transpose(1, 2).reshape(
bsz, q_len, self.num_heads * self.head_dim)
key_states = key_states.transpose(1, 2).reshape(
bsz, q_len, self.num_key_value_heads * self.head_dim)
value_states = value_states.transpose(1, 2).reshape(
bsz, q_len, self.num_key_value_heads * self.head_dim)

attn_output, attn_weights, _ = scaled_multihead_dot_product_attention(
query=query_states,
key=key_states,
value=value_states,
n_heads=self.num_heads,
kv_n_heads=self.num_key_value_heads,
past_key_value=None,
softmax_scale=None,
attn_bias=attention_mask,
key_padding_mask=None,
is_causal=False, # The causal mask is propagated from LlamaForCausalLM
dropout_p=0,
training=self.training,
needs_weights=False,
)
### MAIN MODIFICATIONS END HERE ###

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size //
self.config.pretraining_tp,
dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size //
self.config.pretraining_tp,
dim=1)
attn_output = sum([
F.linear( # type: ignore (thirdParty)
attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, None # type: ignore (thirdParty)


def llama_attention_patch_triton(
self, # type: ignore
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
'use_cache is not yet supported when patching Llama attention.')
# output_attentions is not supported for triton attention
if output_attentions:
raise NotImplementedError(
'output_attentions is not supported when patching Llama attention with triton attention.'
)
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads *
self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp,
dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)

key_states = [
F.linear(hidden_states, key_slices[i]) # type: ignore (thirdParty)
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)

value_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin,
position_ids) # type: ignore (thirdParty)

### MAIN MODIFICATIONS START HERE ###
query_states = query_states.transpose(1, 2).reshape(
bsz, q_len, self.num_heads * self.head_dim)
key_states = key_states.transpose(1, 2).reshape(
bsz, q_len, self.num_key_value_heads * self.head_dim)
value_states = value_states.transpose(1, 2).reshape(
bsz, q_len, self.num_key_value_heads * self.head_dim)

attn_output, _, _ = triton_flash_attn_fn(
query=query_states,
key=key_states,
value=value_states,
n_heads=self.num_heads,
kv_n_heads=self.num_key_value_heads,
past_key_value=None,
softmax_scale=None,
attn_bias=attention_mask,
key_padding_mask=None,
is_causal=False, # The causal mask is propagated from LlamaForCausalLM
dropout_p=0,
training=self.training,
needs_weights=False,
)
### MAIN MODIFICATIONS END HERE ###

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size //
self.config.pretraining_tp,
dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size //
self.config.pretraining_tp,
dim=1)
attn_output = sum([
F.linear( # type: ignore (thirdParty)
attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)

return attn_output, None, None # type: ignore (thirdParty)
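The core of both patch functions is the layout conversion in the "MAIN MODIFICATIONS" block: HF Llama keeps q/k/v as (bsz, n_heads, seq_len, head_dim), while the foundry attention implementations expect flattened (bsz, seq_len, n_heads * head_dim) inputs plus explicit n_heads/kv_n_heads arguments. A small standalone illustration follows (reshape is used because the post-RoPE tensors are contiguous in the head-major layout, so a plain view of the transposed tensor would fail).

import torch

bsz, n_heads, seq_len, head_dim = 2, 32, 128, 128
q = torch.randn(bsz, n_heads, seq_len, head_dim)                 # HF layout after RoPE
q_flat = q.transpose(1, 2).reshape(bsz, seq_len, n_heads * head_dim)
assert q_flat.shape == (bsz, seq_len, n_heads * head_dim)
# q_flat (and the matching key/value tensors) is what gets handed to
# scaled_multihead_dot_product_attention or triton_flash_attn_fn, with the
# causal mask arriving pre-built via attn_bias from LlamaForCausalLM.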
7 changes: 7 additions & 0 deletions scripts/train/train.py
@@ -173,6 +173,11 @@ def main(cfg: DictConfig):
# Check for incompatibilities between the model and data loaders
validate_config(cfg)

max_split_size_mb = cfg.get('max_split_size_mb', None)
if max_split_size_mb is not None:
os.environ[
'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'

# Filter deprecation warning from torch internal usage
warnings.filterwarnings(
action='ignore',
@@ -323,6 +328,8 @@ def main(cfg: DictConfig):
dist_timeout=cfg.dist_timeout,
)

torch.cuda.empty_cache()

print('Logging config...')
log_config(cfg)

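The max_split_size_mb hook added to train.py simply forwards a config key into PyTorch's caching-allocator settings. Outside the trainer, the equivalent is a sketch like the one below (the value 512 is an arbitrary example); the environment variable must be set before the first CUDA allocation, since the allocator reads it when it initializes.

import os

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'  # arbitrary example value

import torch

x = torch.ones(1, device='cuda')  # allocator initializes here with the setting above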
3 changes: 2 additions & 1 deletion setup.py
@@ -47,8 +47,9 @@
]

install_requires = [
'mosaicml[libcloud,nlp,wandb,mlflow]>=0.15.0,<0.16',
'mosaicml[libcloud,wandb,mlflow]>=0.15.0,<0.16',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.31,<4.32',
'mosaicml-streaming>=0.5.1,<0.6',
'torch>=1.13.1,<=2.0.1',
'datasets==2.10.1',