Merge pull request #32 from naubull2/main
Add support for GPTNeoX models
yukang2017 authored Oct 3, 2023
2 parents 61842db + 02e4c1c commit 04c8db1
Showing 3 changed files with 202 additions and 9 deletions.
23 changes: 18 additions & 5 deletions fine-tune.py
@@ -24,6 +24,7 @@
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from llama_attn_replace import replace_llama_attn
+from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier

@@ -39,7 +40,8 @@

@dataclass
class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
+    model_type: Optional[str] = field(default="gpt-neox")

@dataclass
class TrainingArguments(transformers.TrainingArguments):
@@ -99,7 +101,11 @@ def train():
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()

-    replace_llama_attn(training_args.use_flash_attn)
+    # NOTE: May expand supported model types in the future
+    if model_args.model_type == "gpt-neox":
+        replace_gpt_neox_attn(training_args.use_flash_attn)
+    else:
+        replace_llama_attn(training_args.use_flash_attn)

# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
@@ -117,14 +123,15 @@ def train():
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
+        torch_dtype=torch.bfloat16,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
-        use_fast=False,
+        use_fast=True,
)

special_tokens_dict = dict()
@@ -157,10 +164,16 @@ def train():
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

if training_args.low_rank_training:
+        if model_args.model_type == "gpt-neox":
+            # `dense` is added alongside `query_key_value` so the LoRA targets mirror llama's q/k/v/o projections
+            targets = ["query_key_value", "dense"]
+        else:
+            targets = ["q_proj", "k_proj", "v_proj", "o_proj"]

config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
target_modules=targets,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
@@ -169,9 +182,9 @@
# enable trainable params
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]

model.config.use_cache = False # required for gradient checkpointing
model.enable_input_require_grads() # required for gradient checkpointing
model.gradient_checkpointing_enable() # enable gradient checkpointing

trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args,
train_dataset=dataset["train"],
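Note (illustrative, not part of the commit): the target-module branch above differs per architecture because GPT-NeoX fuses the query/key/value projections into a single `query_key_value` linear layer and names its attention output projection `dense`, while LLaMA exposes separate `q_proj`/`k_proj`/`v_proj`/`o_proj` modules. A quick way to confirm the module names, assuming `torch` and `transformers` are installed (any GPT-NeoX checkpoint, e.g. a smaller Pythia variant, reports the same names):

import torch
import transformers

# Build the model from its config alone (random weights, no fine-tuned checkpoint
# needed) and list the linear layers inside the attention blocks.
cfg = transformers.AutoConfig.from_pretrained("EleutherAI/pythia-1.4b-deduped")
model = transformers.AutoModelForCausalLM.from_config(cfg)
names = {name.split(".")[-1]
         for name, module in model.named_modules()
         if isinstance(module, torch.nn.Linear) and ".attention." in name}
print(sorted(names))  # expected: ['dense', 'query_key_value']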
166 changes: 166 additions & 0 deletions gptneox_attn_replace.py
@@ -0,0 +1,166 @@
# Modified based on https://github.com/dvlab-research/LongLoRA

from typing import Optional, Tuple
import warnings
import torch
import transformers

from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


group_size_ratio = 1/4

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
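    # broadcast cos/sin over the batch, then gather the rotary table entries at each token's position id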
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def _flash_attn(query, key, value, attention_mask=None, head_mask=None):
# transform the data into the qkv packed form
qkv = torch.stack(
[query, key, value], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
bsz, q_len = qkv.shape[:2]

qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)

# disable attn weights by returning None when using flash attention
return output, None


def get_forward_function(use_flash_attn=True, use_full=False):

def forward_attention(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
# NOTE: compute SS group size
bsz, q_len, _ = hidden_states.size()
has_layer_past = layer_past is not None

# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)

# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape)

# [batch, seq_len, num_attention_heads, 3 * head_size]
# --> 3 [batch, num_attention_heads, seq_len, head_size]
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# [bsz, nh, q_len, hd]

# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]

# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None

# NOTE: apply shift
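        # Shifted sparse attention (S^2-Attn): during training, tokens are split into
        # groups of q_len * group_size_ratio, half of the attention heads are shifted
        # by half a group so information can flow across group boundaries, and the
        # causal mask is cropped to one group and tiled per (batch, group) pair.
        # At inference time (or with use_full=True) the unshifted full-length path is used.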
group_size = int(q_len * group_size_ratio)
if q_len % group_size > 0:
raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
num_group = q_len // group_size
if self.training and not use_full:
def shift(qkv, num_heads, head_dim):
# qkv = [bsz, nh, q_len, d]
qkv = qkv.transpose(1, 2)
# qkv = [bsz, q_len, nh, d]
qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)

# -> [bsz * n_group, group_s, nh, d)
# -> [bsz * n_group, nh, group_s, d)
qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
return qkv

# contiguous is required as self._attn() will attempt to apply .view() on them
query = shift(query, self.num_attention_heads, self.head_size).contiguous()
key = shift(key, self.num_attention_heads, self.head_size).contiguous()
value = shift(value, self.num_attention_heads, self.head_size).contiguous()

attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)

# Compute attention
if use_flash_attn:
attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

# NOTE: shift back
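        # Merge the groups back into the full sequence, then roll the shifted half of the
        # heads forward by half a group so the outputs realign with the original token order.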
if self.training and not use_full:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
# [bsz, q_len, nh, hd]
attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
attn_output = attn_output.transpose(1, 2)

# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.dense(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)

return outputs

return forward_attention


def replace_gpt_neox_attn(use_flash_attn=True, use_full=False):
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if use_flash_attn and cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
"Resorting to plain attention..."
)
use_flash_attn = False

forward_fn = get_forward_function(use_flash_attn, use_full)
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn
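A minimal usage sketch (illustrative, not part of the commit), assuming a GPT-NeoX checkpoint such as the Pythia default above: the patch swaps GPTNeoXAttention.forward at the class level, and fine-tune.py applies it before the model is built.

import torch
import transformers
from gptneox_attn_replace import replace_gpt_neox_attn

# Patch GPT-NeoX attention first (requires a CUDA device; flash attention needs
# compute capability >= 8.0, otherwise the patch falls back to plain attention),
# then load the model as usual.
replace_gpt_neox_attn(use_flash_attn=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1.4b-deduped", torch_dtype=torch.bfloat16
)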
22 changes: 18 additions & 4 deletions supervised-fine-tune.py
@@ -27,6 +27,7 @@
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from llama_attn_replace import replace_llama_attn
+from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier

@@ -64,7 +65,8 @@ def jload(f, mode="r"):

@dataclass
class ModelArguments:
-    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
+    model_type: Optional[str] = field(default="gpt-neox")


@dataclass
@@ -219,7 +221,11 @@ def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    replace_llama_attn(training_args.use_flash_attn, True)
+    # NOTE: May expand supported model types in the future
+    if model_args.model_type == "gpt-neox":
+        replace_gpt_neox_attn(training_args.use_flash_attn)
+    else:
+        replace_llama_attn(training_args.use_flash_attn)

# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
@@ -237,14 +243,15 @@ def train():
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
+        torch_dtype=torch.bfloat16,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
-        use_fast=False,
+        use_fast=True,
)

special_tokens_dict = dict()
@@ -266,10 +273,16 @@ def train():
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

if training_args.low_rank_training:
+        if model_args.model_type == "gpt-neox":
+            # `dense` is added alongside `query_key_value` so the LoRA targets mirror llama's q/k/v/o projections
+            targets = ["query_key_value", "dense"]
+        else:
+            targets = ["q_proj", "k_proj", "v_proj", "o_proj"]

config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
target_modules=targets,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
@@ -278,6 +291,7 @@ def train():
# enable trainable params
[p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]

+    model.config.use_cache = False # required for gradient checkpointing
model.enable_input_require_grads() # required for gradient checkpointing
model.gradient_checkpointing_enable() # enable gradient checkpointing

