diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index e3eaf3ad0c..d250239931 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -65,6 +65,7 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
             cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
             cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
             cfg.attention_patch_type (str, optional): Which attention patch to use for llama models. Default: ``None``.
+                Deprecated. Will automatically use flash attention 2.
             cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
     """
@@ -89,12 +90,7 @@ def __init__(self, om_model_config: DictConfig,
         use_auth_token = om_model_config.get('use_auth_token', False)
         use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
                                                     False)
-        requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
         load_in_8bit = om_model_config.get('load_in_8bit', False)
-        if use_flash_attention_2 and not is_flash_v2_installed():
-            raise ValueError(
-                'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
-                + 'Please `pip install llm-foundry[gpu-flash2]`.')
 
         # Set up config args for the model construction and base classes
         z_loss = om_model_config.get('z_loss', 0.0)
@@ -102,6 +98,22 @@ def __init__(self, om_model_config: DictConfig,
         # Resolve "mixed" init device to either "cpu" or "meta"
         resolved_init_device = hf_get_init_device(init_device)
         attention_patch_type = om_model_config.get('attention_patch_type', None)
+        if attention_patch_type is not None:
+            warnings.warn(
+                VersionedDeprecationWarning(
+                    'attention_patch_type is deprecated and will automatically use flash attention 2. '
+                    +
+                    'We recommend `use_flash_attention_2: true` for llama models.',
+                    remove_version='0.7.0'))
+            use_flash_attention_2 = True
+
+        requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
+
+        if use_flash_attention_2 and not is_flash_v2_installed():
+            raise ValueError(
+                'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+                + 'Please `pip install llm-foundry[gpu-flash2]`.')
+
         peft_config_dict = pop_config(om_model_config,
                                       'peft_config',
                                       must_exist=False,
@@ -246,9 +258,6 @@ def _autoset_attn_implementation_monkeypatch(
             if dist.get_local_rank() == 0:
                 os.remove(signal_file_path)
 
-        if attention_patch_type is not None:
-            self._patch_attention_type(model, attention_patch_type)
-
         # Hugging Face's weight tying does not succeed if the model is inited on meta device
         # so we manually apply the weight tying here
         if model.config.tie_word_embeddings and resolved_init_device == 'meta':
@@ -278,29 +287,6 @@ def _autoset_attn_implementation_monkeypatch(
             peft_config=peft_config,
         )
 
-    @staticmethod
-    def _patch_attention_type(model: PreTrainedModel,
-                              attention_patch_type: str) -> None:
-        if model.config.model_type != 'llama':
-            raise ValueError(
-                f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
-            )
-
-        warnings.warn(
-            VersionedDeprecationWarning(
-                'Attention patches for Llama models are deprecated. We recommend `use_flash_attention_2: True` for Llama models.',
-                remove_version='0.7.0'))
-
-        log.debug(
-            f'Patching llama attention with {attention_patch_type} attention')
-        from transformers.models.llama.modeling_llama import LlamaAttention
-
-        from llmfoundry.models.layers.llama_attention_monkeypatch import \
-            get_llama_attention_patch_fn
-        LlamaAttention.forward = get_llama_attention_patch_fn(
-            attention_patch_type)
-        model.config.use_cache = False
-
     @staticmethod
     def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
         if peft_installed:
diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py
index 72bfbd975b..892d8d5a11 100644
--- a/llmfoundry/utils/warnings.py
+++ b/llmfoundry/utils/warnings.py
@@ -9,7 +9,7 @@
 ]
 
 
-class VersionedDeprecationWarning(DeprecationWarning):
+class VersionedDeprecationWarning(UserWarning):
     """A custom deprecation warning class that includes version information.
 
     Attributes:
diff --git a/scripts/train/README.md b/scripts/train/README.md
index 57e3b1947f..0d9a335848 100644
--- a/scripts/train/README.md
+++ b/scripts/train/README.md
@@ -372,15 +372,6 @@ model:
 ```
 HuggingFace models currently only support Flash Attention V2.
 
-For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
-```yaml
-model:
-  name: hf_causal_lm
-  pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
-  attention_patch_type: triton
-  ...
-```
-
 # FAQ: How many GPUs do I need to train a LLM?
 This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
 activation checkpointing, and `DecoupledLionW`, then a good rule of thumb is:
diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py
index b10e9ea5d1..818136e8fa 100644
--- a/tests/models/layers/test_huggingface_flash.py
+++ b/tests/models/layers/test_huggingface_flash.py
@@ -10,7 +10,6 @@
 import transformers
 from composer.core.precision import get_precision_context
 from composer.utils import reproducibility
-from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from transformers.models.llama.modeling_llama import LlamaAttention
 
@@ -104,56 +103,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
     assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
 
 
-@pytest.mark.gpu
-@pytest.mark.parametrize('patch', ['triton', 'torch'])
-def test_attn_patch_integration(patch: str):
-    if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
-        pytest.skip(
-            'The CI cluster does not have access to the Llama models, so skip this test.'
-        )
-
-    # Save the original attention function to restore at the end of the test.
-    from transformers.models.llama.modeling_llama import LlamaAttention
-    original_attn = LlamaAttention.forward
-
-    name = 'meta-llama/Llama-2-7b-hf'
-    model_cfg = DictConfig({
-        'name': 'hf_causal_lm',
-        'pretrained_model_name_or_path': name,
-        'config_overrides': {
-            'num_hidden_layers': 2,
-            'intermediate_size': 64,
-            'hidden_size': 64,
-        },
-        'use_auth_token': True,
-        'pretrained': False,
-        'init_device': 'cpu',
-        'attention_patch_type': patch
-    })
-
-    tokenizer = build_tokenizer(name, tokenizer_kwargs={})
-    tokenizer.pad_token = tokenizer.eos_token
-
-    model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
-
-    tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
-                                return_tensors='pt',
-                                padding=True)
-    tokenized_input['labels'] = tokenized_input['input_ids'].clone()
-
-    tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
-    model.to('cuda')
-
-    with get_precision_context('amp_bf16'):
-        # We're just testing that the attention patch runs okay
-        outputs = model(tokenized_input)
-        loss = outputs.loss
-        loss.backward()
-
-    # Ensure the patch does not persist beyond this test.
-    LlamaAttention.forward = original_attn
-
-
 @pytest.mark.gpu
 @pytest.mark.world_size(2)
 @pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
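
For configs that still set `attention_patch_type`, the deprecation path in this patch points at the replacement: the key now only triggers the `VersionedDeprecationWarning` and forces `use_flash_attention_2 = True`, so flash-attention 2 (`pip install llm-foundry[gpu-flash2]`) is required either way. A minimal migration sketch, reusing the illustrative Llama checkpoint from the removed README example:

```yaml
model:
  name: hf_causal_lm
  pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
  # attention_patch_type: triton  # deprecated; drop this key (removed in v0.7.0)
  use_flash_attention_2: true
  ...
```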