From 80902c518090208ecb26d34077cf5cf0de623661 Mon Sep 17 00:00:00 2001 From: Xialie Zhuang <62231346+ZhuangXialie@users.noreply.github.com> Date: Tue, 23 Jul 2024 18:12:31 +0800 Subject: [PATCH 1/2] Add files via upload full train --- full_supervised_finetuning.py | 926 ++++++++++++++++++++++++++++++++++ 1 file changed, 926 insertions(+) create mode 100644 full_supervised_finetuning.py diff --git a/full_supervised_finetuning.py b/full_supervised_finetuning.py new file mode 100644 index 0000000..9fe25dd --- /dev/null +++ b/full_supervised_finetuning.py @@ -0,0 +1,926 @@ +# -*- coding: utf-8 -*- +""" +@author:Xialie Zhuang(1832963123@qq.com) +@description: Train a model from base model using Full train +""" +import math +import os +from dataclasses import dataclass, field +from glob import glob +from types import MethodType +from typing import Literal, Optional, Tuple + +import torch +import torch.nn as nn +from datasets import load_dataset +from loguru import logger +from transformers import ( + AutoConfig, + BloomForCausalLM, + AutoModel, + AutoModelForCausalLM, + LlamaForCausalLM, + BloomTokenizerFast, + AutoTokenizer, + HfArgumentParser, + Trainer, + Seq2SeqTrainingArguments, + set_seed, + BitsAndBytesConfig, + DataCollatorForSeq2Seq, +) +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + apply_rotary_pos_emb, + repeat_kv, + LlamaFlashAttention2, + Cache +) +from transformers.trainer import TRAINING_ARGS_NAME +from transformers.trainer_pt_utils import LabelSmoother +from transformers.utils.versions import require_version + +try: + from transformers.integrations import is_deepspeed_zero3_enabled +except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1 + from transformers.deepspeed import is_deepspeed_zero3_enabled + +is_flash_attn_2_available = False +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input + + is_flash_attn_2_available = True +except ImportError: + is_flash_attn_2_available = False + +from template import get_conv_template + +MODEL_CLASSES = { + "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast), + "chatglm": (AutoConfig, AutoModel, AutoTokenizer), + "llama": (AutoConfig, LlamaForCausalLM, AutoTokenizer), + "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer), + "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer), +} + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_type: str = field( + default=None, + metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())} + ) + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." + ) + }, + ) + load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."}) + load_in_4bit: bool = field(default=False, metadata={"help": "Whether to load the model in 4bit mode or not."}) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The tokenizer for weights initialization.Don't set if you want to train a model from scratch." 
+ ) + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: Optional[str] = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}) + use_fast_tokenizer: bool = field( + default=False, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + torch_dtype: Optional[str] = field( + default="float16", + metadata={ + "help": ( + "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " + "dtype will be automatically derived from the model's weights." + ), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + device_map: Optional[str] = field( + default=None, + metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "}, + ) + trust_remote_code: bool = field( + default=True, + metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."}, + ) + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( + default=None, + metadata={"help": "Adopt scaled rotary positional embeddings."} + ) + flash_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable FlashAttention-2 for faster training."} + ) + shift_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable shifted sparse attention (S^2-Attn) proposed by LongLoRA."} + ) + neft_alpha: Optional[float] = field( + default=0, + metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune. value can be 5."} + ) + + def __post_init__(self): + if self.model_type is None: + raise ValueError( + "You must specify a valid model_type to run training. Available model types are " + ", ".join( + MODEL_CLASSES.keys())) + if self.model_name_or_path is None: + raise ValueError("You must specify a valid model_name_or_path to run training.") + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train jsonl data file folder."}) + validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={"help": "If only pad tokens should be ignored. 
This assumes that `config.pad_token_id` is defined."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=1, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + + def __post_init__(self): + if self.max_train_samples is not None and 0 < self.max_train_samples <= 1000: + logger.warning("You may set max_train_samples = -1 to run all samples in production.") + + +@dataclass +class ScriptArguments: + use_peft: bool = field(default=False, metadata={"help": "Whether to use peft"}) + train_on_inputs: bool = field(default=False, metadata={"help": "Whether to train on inputs"}) + target_modules: Optional[str] = field(default="all") + lora_rank: Optional[int] = field(default=8) + lora_dropout: Optional[float] = field(default=0.05) + lora_alpha: Optional[float] = field(default=32.0) + modules_to_save: Optional[str] = field(default=None) + peft_path: Optional[str] = field(default=None, metadata={"help": "The path to the peft model"}) + qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"}) + model_max_length: int = field( + default=512, + metadata={"help": "Maximum model context length. suggest: 8192 * 4, 8192 * 2, 8192, 4096, 2048, 1024, 512"} + ) + template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."}) + + def __post_init__(self): + if self.model_max_length < 60: + raise ValueError("You must specify a valid model_max_length >= 60 to run training") + + +class SavePeftModelTrainer(Trainer): + """ + Trainer for lora models + """ + + def save_model(self, output_dir=None, _internal_call=False): + """Save the LoRA model.""" + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + self.model.save_pretrained(output_dir) + + +def save_model(model, tokenizer, args): + """Save the model and the tokenizer.""" + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + # Take care of distributed/parallel training + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + +def save_model_zero3(model, tokenizer, args, trainer): + """Save the model for deepspeed zero3. + refer https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train_lora.py#L209 + """ + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + model_to_save = model.module if hasattr(model, "module") else model + model_to_save.save_pretrained(args.output_dir, state_dict=state_dict_zero3) + tokenizer.save_pretrained(output_dir) + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. 
+ """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def find_all_linear_names(peft_model, int4=False, int8=False): + """Find all linear layer names in the model. reference from qlora paper.""" + cls = torch.nn.Linear + if int4 or int8: + import bitsandbytes as bnb + if int4: + cls = bnb.nn.Linear4bit + elif int8: + cls = bnb.nn.Linear8bitLt + lora_module_names = set() + for name, module in peft_model.named_modules(): + if isinstance(module, cls): + # last layer is not add to lora_module_names + if 'lm_head' in name: + continue + if 'output_layer' in name: + continue + names = name.split('.') + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + return sorted(lora_module_names) + + +# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def llama_torch_attn_forward( + self: "LlamaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: torch.Tensor) -> torch.Tensor: + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2:].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not 
None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2:].roll(groupsz // 2, dims=1), + ) + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def llama_flash_attn_forward( + self: "LlamaFlashAttention2", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning("The input hidden states seems to be silently casted in float32.") + query_states = query_states.to(target_dtype) + key_states = 
key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: torch.Tensor) -> torch.Tensor: + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2:].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + attn_output: torch.Tensor = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2:].roll(groupsz // 2, dims=1), + ) + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def apply_llama_patch() -> None: + LlamaAttention.forward = llama_torch_attn_forward + LlamaFlashAttention2.forward = llama_flash_attn_forward + + +def main(): + parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, ScriptArguments)) + model_args, data_args, training_args, script_args = parser.parse_args_into_dataclasses() + + logger.info(f"Model args: {model_args}") + logger.info(f"Data args: {data_args}") + logger.info(f"Training args: {training_args}") + logger.info(f"Script args: {script_args}") + logger.info( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type] + # Load tokenizer + tokenizer_kwargs = { + "cache_dir": model_args.cache_dir, + "use_fast": model_args.use_fast_tokenizer, + "trust_remote_code": model_args.trust_remote_code, + } + tokenizer_name_or_path = model_args.tokenizer_name_or_path + if not tokenizer_name_or_path: + tokenizer_name_or_path = model_args.model_name_or_path + tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs) + prompt_template = get_conv_template(script_args.template_name) + if tokenizer.eos_token_id is None: + tokenizer.eos_token = prompt_template.stop_str # eos token is required + tokenizer.add_special_tokens({"eos_token": tokenizer.eos_token}) + logger.info(f"Add eos_token: {tokenizer.eos_token}, eos_token_id: {tokenizer.eos_token_id}") + if tokenizer.bos_token_id is None: + tokenizer.add_special_tokens({"bos_token": tokenizer.eos_token}) + tokenizer.bos_token_id = tokenizer.eos_token_id + logger.info(f"Add bos_token: {tokenizer.bos_token}, bos_token_id: {tokenizer.bos_token_id}") + if tokenizer.pad_token_id is None: + if tokenizer.unk_token_id is not None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + logger.info(f"Add pad_token: {tokenizer.pad_token}, pad_token_id: {tokenizer.pad_token_id}") + logger.debug(f"Tokenizer: {tokenizer}") + + IGNORE_INDEX = LabelSmoother.ignore_index if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + + # Get datasets + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + ) + if "validation" not in raw_datasets.keys(): + shuffled_train_dataset = raw_datasets["train"].shuffle(seed=42) + # Split the shuffled train dataset into training and validation sets + split = shuffled_train_dataset.train_test_split( + test_size=data_args.validation_split_percentage / 100, + seed=42 + ) + # Assign the split datasets back to raw_datasets + raw_datasets["train"] = split["train"] + raw_datasets["validation"] = split["test"] + else: + # Loading a dataset from local files. + data_files = {} + if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir): + train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob( + f'{data_args.train_file_dir}/**/*.jsonl', recursive=True) + logger.info(f"train files: {train_data_files}") + data_files["train"] = train_data_files + if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir): + eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob( + f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True) + logger.info(f"eval files: {eval_data_files}") + data_files["validation"] = eval_data_files + raw_datasets = load_dataset( + 'json', + data_files=data_files, + cache_dir=model_args.cache_dir, + ) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. 
+ if "validation" not in raw_datasets.keys(): + shuffled_train_dataset = raw_datasets["train"].shuffle(seed=42) + split = shuffled_train_dataset.train_test_split( + test_size=float(data_args.validation_split_percentage / 100), + seed=42 + ) + raw_datasets["train"] = split["train"] + raw_datasets["validation"] = split["test"] + logger.info(f"Raw datasets: {raw_datasets}") + + # Preprocessing the datasets + max_length = script_args.model_max_length + + def preprocess_function(examples): + """ + Preprocessing the datasets. + part of code modified from https://github.com/lm-sys/FastChat + """ + input_ids_list = [] + attention_mask_list = [] + targets_list = [] + roles = ["human", "gpt"] + + def get_dialog(examples): + system_prompts = examples.get("system_prompt", "") + for i, source in enumerate(examples['conversations']): + if len(source) < 2: + continue + data_role = source[0].get("from", "") + if data_role not in roles or data_role != roles[0]: + # Skip the first one if it is not from human + source = source[1:] + if len(source) < 2: + continue + messages = [] + for j, sentence in enumerate(source): + data_role = sentence.get("from", "") + if data_role not in roles: + logger.warning(f"unknown role: {data_role}, {i}. (ignored)") + break + if data_role == roles[j % 2]: + messages.append(sentence["value"]) + if len(messages) % 2 != 0: + continue + # Convert the list to pairs of elements + history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)] + system_prompt = system_prompts[i] if system_prompts else None + yield prompt_template.get_dialog(history_messages, system_prompt=system_prompt) + + for dialog in get_dialog(examples): + input_ids, labels = [], [] + + for i in range(len(dialog) // 2): + source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=(i == 0)) + target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False) + + total_len = len(source_ids) + len(target_ids) + max_source_len = int(max_length * (len(source_ids) / total_len)) + max_target_len = int(max_length * (len(target_ids) / total_len)) + + if len(source_ids) > max_source_len: + source_ids = source_ids[:max_source_len] + if len(target_ids) > max_target_len - 1: # eos token + target_ids = target_ids[:max_target_len - 1] + if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id: + source_ids = source_ids[1:] + if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id: + target_ids = target_ids[:-1] + if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length: + break + + input_ids += source_ids + target_ids + [tokenizer.eos_token_id] # add eos token for each turn + if script_args.train_on_inputs: + labels += source_ids + target_ids + [tokenizer.eos_token_id] + else: + labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id] + + input_ids_list.append(input_ids) + attention_mask_list.append([1] * len(input_ids)) + targets_list.append(labels) + + return dict( + input_ids=input_ids_list, + attention_mask=attention_mask_list, + labels=targets_list, + ) + + def filter_empty_labels(example): + """Remove empty labels dataset.""" + return not all(label == IGNORE_INDEX for label in example["labels"]) + + train_dataset = None + max_train_samples = 0 + if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets['train'].shuffle(seed=42) + max_train_samples = len(train_dataset) + if data_args.max_train_samples is not None and 
data_args.max_train_samples > 0: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + logger.debug(f"Example train_dataset[0]: {train_dataset[0]}") + with training_args.main_process_first(desc="Train dataset tokenization"): + train_dataset = train_dataset.shuffle().map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=train_dataset.column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + train_dataset = train_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers) + logger.debug(f"Num train_samples: {len(train_dataset)}") + logger.debug("Tokenized training example:") + logger.debug(f"Decode input_ids[0]:\n{tokenizer.decode(train_dataset[0]['input_ids'])}") + replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id + for label in list(train_dataset[0]['labels'])] + logger.debug(f"Decode labels[0]:\n{tokenizer.decode(replaced_labels)}") + + eval_dataset = None + max_eval_samples = 0 + if training_args.do_eval: + with training_args.main_process_first(desc="Eval dataset tokenization"): + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = raw_datasets["validation"] + max_eval_samples = len(eval_dataset) + if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + eval_size = len(eval_dataset) + logger.debug(f"Num eval_samples: {eval_size}") + if eval_size > 500: + logger.warning(f"Num eval_samples is large: {eval_size}, " + f"training slow, consider reduce it by `--max_eval_samples=50`") + logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}") + eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=eval_dataset.column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + eval_dataset = eval_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers) + logger.debug(f"Num eval_samples: {len(eval_dataset)}") + logger.debug("Tokenized eval example:") + logger.debug(tokenizer.decode(eval_dataset[0]['input_ids'])) + + # Load model + if model_args.model_name_or_path: + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + ddp = world_size != 1 + if ddp: + model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))} + training_args.gradient_accumulation_steps = training_args.gradient_accumulation_steps // world_size or 1 + + config_kwargs = { + "trust_remote_code": model_args.trust_remote_code, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + config = config_class.from_pretrained(model_args.model_name_or_path, **config_kwargs) + + # Set RoPE scaling + if model_args.rope_scaling is not None: + if hasattr(config, "rope_scaling"): + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK may not work well with fine-tuning. 
" + "See: https://github.com/huggingface/transformers/pull/24653" + ) + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and script_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(script_args.model_max_length / current_max_length)) + else: + logger.warning(f"The model_max_length({script_args.model_max_length}) is smaller than max " + f"length({current_max_length}). Consider increase model_max_length.") + scaling_factor = 1.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info("Using {} scaling strategy and setting scaling factor to {}".format( + model_args.rope_scaling, scaling_factor + )) + else: + logger.warning("Current model does not support RoPE scaling.") + + # Set FlashAttention-2 + if model_args.flash_attn: + if is_flash_attn_2_available: + config_kwargs["use_flash_attention_2"] = True + logger.info("Using FlashAttention-2 for faster training and inference.") + else: + logger.warning("FlashAttention-2 is not installed.") + elif model_args.shift_attn and getattr(config, "model_type", None) == "llama": + logger.warning("Using `--flash_attn` for faster training in large context length, enable if your GPU" + " is RTX3090, RTX4090, A100 or H100.") + + # Set shifted sparse attention (S^2-Attn) + if model_args.shift_attn: + if getattr(config, "model_type", None) == "llama": + setattr(config, "group_size_ratio", 0.25) + apply_llama_patch() + logger.info("Using shifted sparse attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shifted sparse attention.") + + load_in_4bit = model_args.load_in_4bit + load_in_8bit = model_args.load_in_8bit + if load_in_4bit and load_in_8bit: + raise ValueError("Error, load_in_4bit and load_in_8bit cannot be set at the same time") + elif load_in_8bit or load_in_4bit: + logger.info(f"Quantizing model, load_in_4bit: {load_in_4bit}, load_in_8bit: {load_in_8bit}") + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + if load_in_8bit: + config_kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True) + elif load_in_4bit: + if script_args.qlora: + config_kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch_dtype, + ) + else: + config_kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch_dtype, + ) + + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + torch_dtype=torch_dtype, + **config_kwargs, + ) + + # Fix ChatGLM2 and ChatGLM3 LM head + if getattr(config, "model_type", None) == "chatglm": + setattr(model, "lm_head", model.transformer.output_layer) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + # Set NEFTune trick for fine-tuning + if model_args.neft_alpha > 0: + input_embed = model.get_input_embeddings() + if isinstance(input_embed, torch.nn.Embedding): + def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor: + embeddings = input_embed.__class__.forward(self, x) + dims = self.num_embeddings * self.embedding_dim + mag_norm = model_args.neft_alpha / (dims ** 0.5) + embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm) + return embeddings + + input_embed.forward = MethodType(noisy_forward, input_embed) + logger.info("Using noisy embedding with 
alpha={:.2f}".format(model_args.neft_alpha)) + else: + logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.") + + # Patch Mixtral MOE model + if getattr(config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled(): + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock # type: ignore + + set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + else: + raise ValueError(f"Error, model_name_or_path is None, SFT must be loaded from a pre-trained model") + + # Full parameters training + logger.info("Fine-tuning method: Full parameters training") + model = model.float() + print_trainable_parameters(model) + + # Initialize our Trainer + if training_args.gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False): + model.gradient_checkpointing_enable() + model.config.use_cache = False + logger.info("Gradient checkpointing enabled.") + else: + model.config.use_cache = True + logger.info("Gradient checkpointing disabled.") + model.enable_input_require_grads() + if not ddp and torch.cuda.device_count() > 1: + # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available + model.is_parallelizable = True + model.model_parallel = True + + data_collator = DataCollatorForSeq2Seq( + tokenizer=tokenizer, + model=model, + label_pad_token_id=IGNORE_INDEX, + pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shifted sparse attention + ) + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + logger.info("*** Train ***") + if trainer.is_world_process_zero(): + sample = next(iter(trainer.get_train_dataloader())) + logger.debug(f"Train dataloader example: {sample}") + logger.debug(f"input_ids:\n{list(sample['input_ids'])[:3]}, \nlabels:\n{list(sample['labels'])[:3]}") + logger.debug(f"Decode input_ids[0]:\n{tokenizer.decode(sample['input_ids'][0])}") + replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in + sample['labels'][0]] + logger.debug(f"Decode labels[0]:\n{tokenizer.decode(replaced_labels)}") + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + metrics = train_result.metrics + metrics["train_samples"] = max_train_samples + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + model.config.use_cache = True # enable cache after training + tokenizer.padding_side = "left" # restore padding side + tokenizer.init_kwargs["padding_side"] = "left" + + if trainer.is_world_process_zero(): + logger.debug(f"Training metrics: {metrics}") + logger.info(f"Saving model checkpoint to {training_args.output_dir}") + if is_deepspeed_zero3_enabled(): + save_model_zero3(model, tokenizer, training_args, trainer) + else: + save_model(model, tokenizer, training_args) + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate(metric_key_prefix="eval") + + metrics["eval_samples"] = max_eval_samples + 
try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
+        metrics["perplexity"] = perplexity
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+        if trainer.is_world_process_zero():
+            logger.debug(f"Eval metrics: {metrics}")
+
+
+if __name__ == "__main__":
+    main()

From ba1febd0ff24b4524b5203109f686d057e911a77 Mon Sep 17 00:00:00 2001
From: Xialie Zhuang <62231346+ZhuangXialie@users.noreply.github.com>
Date: Tue, 23 Jul 2024 18:15:47 +0800
Subject: [PATCH 2/2] Add files via upload

full_train.sh
---
 run_full_train.sh | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)
 create mode 100644 run_full_train.sh

diff --git a/run_full_train.sh b/run_full_train.sh
new file mode 100644
index 0000000..580c8f8
--- /dev/null
+++ b/run_full_train.sh
@@ -0,0 +1,35 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 full_supervised_finetuning.py \
+    --model_type auto \
+    --cache_dir ./model \
+    --model_name_or_path ./model/glm-4-9b-chat \
+    --train_file_dir ./ \
+    --validation_file_dir ./ \
+    --per_device_train_batch_size 2 \
+    --do_train \
+    --num_train_epochs 15 \
+    --per_device_eval_batch_size 2 \
+    --max_train_samples -1 \
+    --learning_rate 3e-5 \
+    --warmup_ratio 0.2 \
+    --model_max_length 2048 \
+    --weight_decay 0.01 \
+    --logging_strategy steps \
+    --logging_steps 1 \
+    --save_steps 400 \
+    --save_strategy steps \
+    --save_total_limit 3 \
+    --gradient_accumulation_steps 1 \
+    --preprocessing_num_workers 128 \
+    --output_dir GLM4-sft-med \
+    --overwrite_output_dir \
+    --ddp_timeout 30000 \
+    --logging_first_step True \
+    --target_modules all \
+    --torch_dtype bfloat16 \
+    --report_to tensorboard \
+    --neft_alpha 8 \
+    --ddp_find_unused_parameters False \
+    --gradient_checkpointing True \
+    --template_name chatglm3 \
+    --deepspeed ds_zero_2.json \
+    --fp16
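
Note on data layout (not part of this patch): preprocess_function in full_supervised_finetuning.py reads ShareGPT-style records, i.e. a "conversations" list of alternating "human"/"gpt" turns with "from"/"value" keys plus an optional "system_prompt" string, and --train_file_dir / --validation_file_dir are globbed recursively for *.json and *.jsonl files. A minimal sketch of producing one such record is shown below; the directory, file name and example text are illustrative assumptions only, not files shipped with this PR.

import json
import os

# Illustrative only: one training record in the ShareGPT-style format that
# preprocess_function consumes. "conversations" must start with a "human" turn
# and alternate human/gpt; "system_prompt" is optional.
record = {
    "system_prompt": "You are a helpful assistant.",
    "conversations": [
        {"from": "human", "value": "What does the --neft_alpha option do?"},
        {"from": "gpt", "value": "It adds uniform noise to the input embeddings (NEFTune) during training."},
    ],
}

# The training script globs *.json and *.jsonl recursively under --train_file_dir.
os.makedirs("data/train", exist_ok=True)  # hypothetical directory
with open("data/train/example.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")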