Handle flash_attn version update in transformers main #643

Closed · wants to merge 5 commits
7 changes: 5 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -24,8 +24,6 @@

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.utils import init_empty_weights

try:
@@ -141,6 +139,9 @@ def __init__(self, om_model_config: Union[DictConfig,
# Rank 0 will still be pretrained, and distribute the weights appropriately
if dist.get_local_rank() != 0 and init_device == 'mixed':
om_model_config.pretrained = False

if config.model_type == 'llama':
transformers.utils.is_flash_attn_available = lambda : False

# initialize the model on the correct device
if resolved_init_device == 'cpu':
@@ -198,6 +199,8 @@ def __init__(self, om_model_config: Union[DictConfig,
log.debug(
f'Patching llama attention with {attention_patch_type} attention'
)
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from transformers.models.llama.modeling_llama import \
LlamaAttention
LlamaAttention.forward = get_llama_attention_patch_fn(
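For readers following along, here is a minimal sketch of how the two pieces of this file fit together: overriding `transformers.utils.is_flash_attn_available` before the llama model is constructed, then applying the attention monkeypatch afterwards. The wrapper function and its arguments are illustrative and not part of the PR; only the override and the lazy import mirror the diff.

```python
# Sketch only: a hypothetical helper showing the ordering used in the diff above.
import transformers


def build_patched_llama(pretrained_name: str, attention_patch_type: str):
    # Force transformers' flash-attn detection to report "unavailable" before the
    # llama model (and its attention classes) are instantiated, likely so that the
    # stock attention path is used and can be monkeypatched predictably.
    transformers.utils.is_flash_attn_available = lambda: False

    model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_name)

    # Import the patch lazily, as the diff does, so transformers' llama internals
    # are only pulled in after the override above is in place.
    from llmfoundry.models.layers.llama_attention_monkeypatch import \
        get_llama_attention_patch_fn
    from transformers.models.llama.modeling_llama import LlamaAttention

    LlamaAttention.forward = get_llama_attention_patch_fn(attention_patch_type)
    return model
```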
1 change: 1 addition & 0 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -186,6 +186,7 @@ def llama_attention_patch_triton(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
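The new `padding_mask` keyword matters because the transformers version targeted by this PR passes `padding_mask` when calling attention modules, so a patched `forward` without that parameter fails with an unexpected-keyword `TypeError`. Below is a standalone sketch (simplified, with a placeholder body; only the `padding_mask` parameter mirrors the diff) showing how accepting and ignoring the keyword keeps both call styles working:

```python
from typing import Optional, Tuple

import torch


def patched_attention_forward(
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    # Newly accepted keyword: newer transformers releases pass padding_mask to the
    # attention forward; it is accepted but unused, as in the patch above.
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError('use_cache is not supported in this sketch.')
    # Placeholder body: a real patch computes attention; this just passes through.
    return hidden_states, None, None


hidden = torch.randn(1, 4, 8)
patched_attention_forward(hidden)  # older call style, no padding_mask
patched_attention_forward(hidden, padding_mask=torch.ones(1, 4, dtype=torch.long))  # newer call style
```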
1 change: 0 additions & 1 deletion scripts/train/train.py
@@ -28,7 +28,6 @@
process_init_device,
update_batch_size_info)


def validate_config(cfg: DictConfig):
"""Validates compatible model and dataloader selection."""
loaders = [cfg.train_loader]