Skip to content

Commit

Permalink
Deprecate attention patching for llama (#1047)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored and KuuCi committed Apr 18, 2024
1 parent be4c012 commit 9dfe080
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 92 deletions.
48 changes: 17 additions & 31 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
cfg.attention_patch_type (str, optional): Which attention patch to use for llama models. Default: ``None``.
Deprecated. Will automatically use flash attention 2.
cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
"""
Expand All @@ -89,19 +90,30 @@ def __init__(self, om_model_config: DictConfig,
use_auth_token = om_model_config.get('use_auth_token', False)
use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
False)
requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
load_in_8bit = om_model_config.get('load_in_8bit', False)
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please `pip install llm-foundry[gpu-flash2]`.')

# Set up config args for the model construction and base classes
z_loss = om_model_config.get('z_loss', 0.0)
init_device = om_model_config.get('init_device', 'cpu')
# Resolve "mixed" init device to either "cpu" or "meta"
resolved_init_device = hf_get_init_device(init_device)
attention_patch_type = om_model_config.get('attention_patch_type', None)
if attention_patch_type is not None:
warnings.warn(
VersionedDeprecationWarning(
'attention_patch_type is deprecated and will automatically use flash attention 2. '
+
'We recommend `use_flash_attention_2: true` for llama models.',
remove_version='0.7.0'))
use_flash_attention_2 = True

requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'

if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please `pip install llm-foundry[gpu-flash2]`.')

peft_config_dict = pop_config(om_model_config,
'peft_config',
must_exist=False,
Expand Down Expand Up @@ -246,9 +258,6 @@ def _autoset_attn_implementation_monkeypatch(
if dist.get_local_rank() == 0:
os.remove(signal_file_path)

if attention_patch_type is not None:
self._patch_attention_type(model, attention_patch_type)

# Hugging Face's weight tying does not succeed if the model is inited on meta device
# so we manually apply the weight tying here
if model.config.tie_word_embeddings and resolved_init_device == 'meta':
Expand Down Expand Up @@ -278,29 +287,6 @@ def _autoset_attn_implementation_monkeypatch(
peft_config=peft_config,
)

@staticmethod
def _patch_attention_type(model: PreTrainedModel,
attention_patch_type: str) -> None:
if model.config.model_type != 'llama':
raise ValueError(
f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
)

warnings.warn(
VersionedDeprecationWarning(
'Attention patches for Llama models are deprecated. We recommend `use_flash_attention_2: True` for Llama models.',
remove_version='0.7.0'))

log.debug(
f'Patching llama attention with {attention_patch_type} attention')
from transformers.models.llama.modeling_llama import LlamaAttention

from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False

@staticmethod
def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
if peft_installed:
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/utils/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
]


class VersionedDeprecationWarning(DeprecationWarning):
class VersionedDeprecationWarning(UserWarning):
"""A custom deprecation warning class that includes version information.
Attributes:
Expand Down
9 changes: 0 additions & 9 deletions scripts/train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -372,15 +372,6 @@ model:
```
HuggingFace models currently only support Flash Attention V2.

For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
```yaml
model:
name: hf_causal_lm
pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
attention_patch_type: triton
...
```

# FAQ: How many GPUs do I need to train a LLM? <a name="howmanygpus"></a>
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
activation checkpointing, and `DecoupledLionW`, then a good rule of thumb is:
Expand Down
51 changes: 0 additions & 51 deletions tests/models/layers/test_huggingface_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import transformers
from composer.core.precision import get_precision_context
from composer.utils import reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers.models.llama.modeling_llama import LlamaAttention

Expand Down Expand Up @@ -104,56 +103,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)


@pytest.mark.gpu
@pytest.mark.parametrize('patch', ['triton', 'torch'])
def test_attn_patch_integration(patch: str):
if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
)

# Save the original attention function to restore at the end of the test.
from transformers.models.llama.modeling_llama import LlamaAttention
original_attn = LlamaAttention.forward

name = 'meta-llama/Llama-2-7b-hf'
model_cfg = DictConfig({
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': name,
'config_overrides': {
'num_hidden_layers': 2,
'intermediate_size': 64,
'hidden_size': 64,
},
'use_auth_token': True,
'pretrained': False,
'init_device': 'cpu',
'attention_patch_type': patch
})

tokenizer = build_tokenizer(name, tokenizer_kwargs={})
tokenizer.pad_token = tokenizer.eos_token

model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)

tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
return_tensors='pt',
padding=True)
tokenized_input['labels'] = tokenized_input['input_ids'].clone()

tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
model.to('cuda')

with get_precision_context('amp_bf16'):
# We're just testing that the attention patch runs okay
outputs = model(tokenized_input)
loss = outputs.loss
loss.backward()

# Ensure the patch does not persist beyond this test.
LlamaAttention.forward = original_attn


@pytest.mark.gpu
@pytest.mark.world_size(2)
@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
Expand Down

0 comments on commit 9dfe080

Please sign in to comment.