diff --git a/fine-tune.py b/fine-tune.py
index 2bbc82ce..9d2e37fc 100644
--- a/fine-tune.py
+++ b/fine-tune.py
@@ -24,6 +24,7 @@
 from torch.utils.data import Dataset
 from transformers import Trainer, DataCollatorForLanguageModeling
 from llama_attn_replace import replace_llama_attn
+from gptneox_attn_replace import replace_gpt_neox_attn
 from peft import LoraConfig, get_peft_model
 from torch.distributed import barrier
 
@@ -39,7 +40,8 @@
 
 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
+    model_type: Optional[str] = field(default="gpt-neox")
 
 @dataclass
 class TrainingArguments(transformers.TrainingArguments):
@@ -99,7 +101,11 @@ def train():
     parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
     model_args, training_args = parser.parse_args_into_dataclasses()
 
-    replace_llama_attn(training_args.use_flash_attn)
+    # NOTE: May expand supported model types in the future
+    if model_args.model_type == "gpt-neox":
+        replace_gpt_neox_attn(training_args.use_flash_attn)
+    else:
+        replace_llama_attn(training_args.use_flash_attn)
 
     # Set RoPE scaling factor
     config = transformers.AutoConfig.from_pretrained(
@@ -117,6 +123,7 @@ def train():
         model_args.model_name_or_path,
         config=config,
         cache_dir=training_args.cache_dir,
+        torch_dtype=torch.bfloat16,
     )
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -124,7 +131,7 @@ def train():
         cache_dir=training_args.cache_dir,
         model_max_length=training_args.model_max_length,
         padding_side="right",
-        use_fast=False,
+        use_fast=True,
     )
 
     special_tokens_dict = dict()
@@ -157,10 +164,16 @@ def train():
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     if training_args.low_rank_training:
+        if model_args.model_type == "gpt-neox":
+            # include `dense` to match the llama targets; 'query_key_value' alone would miss the output projection
+            targets = ["query_key_value", "dense"]
+        else:
+            targets = ["q_proj", "k_proj", "v_proj", "o_proj"]
+
         config = LoraConfig(
             r=8,
             lora_alpha=16,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+            target_modules=targets,
             lora_dropout=0,
             bias="none",
             task_type="CAUSAL_LM",
@@ -169,9 +182,9 @@ def train():
     # enable trainable params
     [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
 
+    model.config.use_cache = False         # required for gradient checkpointing
     model.enable_input_require_grads()     # required for gradient checkpointing
     model.gradient_checkpointing_enable()  # enable gradient checkpointing
-
     trainer = Trainer(
         model=model, tokenizer=tokenizer, args=training_args,
         train_dataset=dataset["train"],
diff --git a/gptneox_attn_replace.py b/gptneox_attn_replace.py
new file mode 100644
index 00000000..ee16bb0b
--- /dev/null
+++ b/gptneox_attn_replace.py
@@ -0,0 +1,166 @@
+# Modified based on https://github.com/dvlab-research/LongLoRA
+
+from typing import Optional, Tuple
+import warnings
+import torch
+import transformers
+
+from einops import rearrange
+from flash_attn import flash_attn_varlen_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+group_size_ratio = 1/4
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def _flash_attn(query, key, value, attention_mask=None, head_mask=None):
+    # transform the data into the qkv packed form
+    qkv = torch.stack(
+        [query, key, value], dim=2
+    )  # [bsz, nh, 3, q_len, hd]
+    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+    bsz, q_len = qkv.shape[:2]
+
+    qkv = rearrange(qkv, "b s ... -> (b s) ...")
+    cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
+    output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True)
+    output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+    # disable attn weights by returning None when using flash attention
+    return output, None
+
+
+def get_forward_function(use_flash_attn=True, use_full=False):
+
+    def forward_attention(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        position_ids: torch.LongTensor,
+        head_mask: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ):
+        # NOTE: compute SS group size
+        bsz, q_len, _ = hidden_states.size()
+        has_layer_past = layer_past is not None
+
+        # Compute QKV
+        # Attention heads [batch, seq_len, hidden_size]
+        # --> [batch, seq_len, (np * 3 * head_size)]
+        qkv = self.query_key_value(hidden_states)
+
+        # [batch, seq_len, (num_heads * 3 * head_size)]
+        # --> [batch, seq_len, num_heads, 3 * head_size]
+        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
+        qkv = qkv.view(*new_qkv_shape)
+
+        # [batch, seq_len, num_attention_heads, 3 * head_size]
+        # --> 3 [batch, num_attention_heads, seq_len, head_size]
+        query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
+        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
+        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
+        # [bsz, nh, q_len, hd]
+
+        # Compute rotary embeddings on rotary_ndims
+        query_rot = query[..., : self.rotary_ndims]
+        query_pass = query[..., self.rotary_ndims :]
+        key_rot = key[..., : self.rotary_ndims]
+        key_pass = key[..., self.rotary_ndims :]
+
+        # Compute token offset for rotary embeddings (when decoding)
+        seq_len = key.shape[-2]
+        if has_layer_past:
+            seq_len += layer_past[0].shape[-2]
+        cos, sin = self.rotary_emb(value, seq_len=seq_len)
+        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+        query = torch.cat((query, query_pass), dim=-1)
+        key = torch.cat((key, key_pass), dim=-1)
+
+        # Cache QKV values
+        if has_layer_past:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+        present = (key, value) if use_cache else None
+
+        # NOTE: apply shift
+        group_size = int(q_len * group_size_ratio)
+        if q_len % group_size > 0:
+            raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
+        num_group = q_len // group_size
+        if self.training and not use_full:
+            def shift(qkv, num_heads, head_dim):
+                # qkv = [bsz, nh, q_len, d]
+                qkv = qkv.transpose(1, 2)
+                # qkv = [bsz, q_len, nh, d]
+                qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
+
+                # -> [bsz * n_group, group_s, nh, d]
+                # -> [bsz * n_group, nh, group_s, d]
+                qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
+                return qkv
+
+            # contiguous is required as self._attn() will attempt to apply .view() on them
+            query = shift(query, self.num_attention_heads, self.head_size).contiguous()
+            key = shift(key, self.num_attention_heads, self.head_size).contiguous()
+            value = shift(value, self.num_attention_heads, self.head_size).contiguous()
+
+            attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
+
+        # Compute attention
+        if use_flash_attn:
+            attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask)
+        else:
+            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        # NOTE: shift back
+        if self.training and not use_full:
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
+            # [bsz, q_len, nh, hd]
+            attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
+            attn_output = attn_output.transpose(1, 2)
+
+        # Reshape outputs
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+        attn_output = self.dense(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+    return forward_attention
+
+
+def replace_gpt_neox_attn(use_flash_attn=True, use_full=False):
+    cuda_major, cuda_minor = torch.cuda.get_device_capability()
+    if use_flash_attn and cuda_major < 8:
+        warnings.warn(
+            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+            "Resorting to plain attention..."
+        )
+        use_flash_attn = False
+
+    forward_fn = get_forward_function(use_flash_attn, use_full)
+    transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn
diff --git a/supervised-fine-tune.py b/supervised-fine-tune.py
index bf22029a..fefbc2b2 100644
--- a/supervised-fine-tune.py
+++ b/supervised-fine-tune.py
@@ -27,6 +27,7 @@
 from torch.utils.data import Dataset
 from transformers import Trainer, DataCollatorForLanguageModeling
 from llama_attn_replace import replace_llama_attn
+from gptneox_attn_replace import replace_gpt_neox_attn
 from peft import LoraConfig, get_peft_model
 from torch.distributed import barrier
 
@@ -64,7 +65,8 @@ def jload(f, mode="r"):
 
 @dataclass
 class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
+    model_type: Optional[str] = field(default="gpt-neox")
 
 
 @dataclass
@@ -219,7 +221,11 @@ def train():
     parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    replace_llama_attn(training_args.use_flash_attn, True)
+    # NOTE: May expand supported model types in the future
+    if model_args.model_type == "gpt-neox":
+        replace_gpt_neox_attn(training_args.use_flash_attn)
+    else:
+        replace_llama_attn(training_args.use_flash_attn)
 
     # Set RoPE scaling factor
     config = transformers.AutoConfig.from_pretrained(
@@ -237,6 +243,7 @@ def train():
         model_args.model_name_or_path,
         config=config,
         cache_dir=training_args.cache_dir,
+        torch_dtype=torch.bfloat16,
     )
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -244,7 +251,7 @@ def train():
         cache_dir=training_args.cache_dir,
         model_max_length=training_args.model_max_length,
         padding_side="right",
-        use_fast=False,
+        use_fast=True,
     )
 
     special_tokens_dict = dict()
@@ -266,10 +273,16 @@ def train():
     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
 
     if training_args.low_rank_training:
+        if model_args.model_type == "gpt-neox":
+            # include `dense` to match the llama targets; 'query_key_value' alone would miss the output projection
+            targets = ["query_key_value", "dense"]
+        else:
+            targets = ["q_proj", "k_proj", "v_proj", "o_proj"]
+
         config = LoraConfig(
             r=8,
             lora_alpha=16,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+            target_modules=targets,
             lora_dropout=0,
             bias="none",
             task_type="CAUSAL_LM",
@@ -278,6 +291,7 @@ def train():
     # enable trainable params
     [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
 
+    model.config.use_cache = False         # required for gradient checkpointing
     model.enable_input_require_grads()     # required for gradient checkpointing
     model.gradient_checkpointing_enable()  # enable gradient checkpointing
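
Usage note (illustrative sketch, not part of the patch): the snippet below shows the call order the two training scripts rely on, namely patching GPTNeoXAttention.forward before the model is instantiated. The checkpoint name and bfloat16 dtype come from the defaults in the diff; everything else (script name, flags) is an assumption for the example.

# Minimal sketch of the intended call order (assumes flash-attn is installed
# and gptneox_attn_replace.py from this patch is importable).
import torch
import transformers

from gptneox_attn_replace import replace_gpt_neox_attn

# Monkey-patch GPTNeoXAttention.forward before the model is created,
# mirroring fine-tune.py / supervised-fine-tune.py.
replace_gpt_neox_attn(use_flash_attn=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1.4b-deduped",  # default model_name_or_path in the diff
    torch_dtype=torch.bfloat16,
)
model.train()  # the shifted group-attention path only runs when self.training is True and use_full is False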