diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 1c6e648c3..6f554ca22 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -95,7 +95,7 @@ Download a training file, such as `alpaca_data.json` for instruction tuning and curl http://${your_ip}:8015/v1/files -X POST -H "Content-Type: multipart/form-data" -F "file=@./alpaca_data.json" -F purpose="fine-tune" ``` -For reranking and embedding models finetuning, the training file [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/finetune/toy_finetune_data.jsonl) is an toy example. +For reranking and embedding models finetuning, the training file [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/blob/1.1/examples/finetune/toy_finetune_data.jsonl) is an toy example. ### 3.2 Create fine-tuning job @@ -199,6 +199,31 @@ Below is an example for the format of the pretraining dataset: {"text": "A boy with a blue tank top sitting watching three dogs."} ``` +#### 3.2.5 Direct Preference Optimization (DPO) + +Use the following command to launch a job for LLM Direct Preference Optimization, such as `meta-llama/Llama-2-7b-hf`: + +```bash +# create a finetuning job +curl http://${your_ip}:8015/v1/fine_tuning/jobs \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "training_file": "test_data.jsonl", + "model": "meta-llama/Llama-2-7b-hf", + "General":{ + "task":"dpo" + } + }' +``` + +Below is an example for the format of the DPO dataset: + +```json +{"system": "", "question": "You will be given a definition of a task first, then some input of the task.\nThis task is about using the specified sentence and converting the sentence to Resource Description Framework (RDF) triplets of the form (subject, predicate object). The RDF triplets generated must be such that the triplets accurately capture the structure and semantics of the input sentence. The input is a sentence and the output is a list of triplets of the form [subject, predicate, object] that capture the relationships present in the sentence. When a sentence has more than 1 RDF triplet possible, the output must contain all of them.\n\nAFC Ajax (amateurs)'s ground is Sportpark De Toekomst where Ajax Youth Academy also play.\nOutput:", "chosen": "[\n [\"AFC Ajax (amateurs)\", \"has ground\", \"Sportpark De Toekomst\"],\n [\"Ajax Youth Academy\", \"plays at\", \"Sportpark De Toekomst\"]\n]", "rejected": " Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\n\n[AFC Ajax (amateurs), hasGround, Sportpark De Toekomst]\n[Ajax Youth Academy, playsAt, Sportpark De Toekomst]\n\nExplanation:\n\n* AFC Ajax (amateurs) is the subject of the first triplet, and hasGround is the predicate that describes the relationship between AFC Ajax (amateurs) and Sportpark De Toekomst.\n* Ajax Youth Academy is the subject of the second triplet, and playsAt is the predicate that describes the relationship between Ajax Youth Academy and Sportpark De Toekomst.\n\nNote that there may be other possible RDF triplets that could be derived from the input sentence, but the above triplets capture the main relationships present in the sentence."} +{"system": "You are an AI assistant. You will be given a task. You must generate a detailed and long answer.", "question": "Generate an approximately fifteen-word sentence that describes all this data: Midsummer House eatType restaurant; Midsummer House food Chinese; Midsummer House priceRange moderate; Midsummer House customer rating 3 out of 5; Midsummer House near All Bar One", "chosen": "Midsummer House is a moderately priced Chinese restaurant with a 3/5 customer rating, located near All Bar One.", "rejected": " Sure! Here's a sentence that describes all the data you provided:\n\n\"Midsummer House is a moderately priced Chinese restaurant with a customer rating of 3 out of 5, located near All Bar One, offering a variety of delicious dishes.\""} +``` + ### 3.3 Manage fine-tuning job Below commands show how to list finetuning jobs, retrieve a finetuning job, cancel a finetuning job and list checkpoints of a finetuning job. diff --git a/comps/finetuning/finetune_config.py b/comps/finetuning/finetune_config.py index 5473cd9aa..0b2faf53d 100644 --- a/comps/finetuning/finetune_config.py +++ b/comps/finetuning/finetune_config.py @@ -58,7 +58,7 @@ def check_report_to(cls, v: str): @validator("task") def check_task(cls, v: str): - assert v in ["instruction_tuning", "pretraining", "rerank", "embedding"] + assert v in ["instruction_tuning", "pretraining", "dpo", "rerank", "embedding"] return v @@ -71,11 +71,13 @@ class DatasetConfig(BaseModel): block_size: int = 512 shuffle: bool = False max_source_length: int = 384 + max_prompt_length: int = 512 padding_side: str = "right" truncation_side: str = "right" max_seq_length: int = 512 truncation: bool = True padding: Union[bool, str] = True + pad_to_max: bool = False mask_input: bool = True mask_response: bool = True data_preprocess_type: str = "neural_chat" @@ -132,6 +134,7 @@ class TrainingConfig(BaseModel): logging_steps: int = 10 deepspeed_config_file: str = "" embedding_training_config: Optional[EmbeddingTrainingConfig] = EmbeddingTrainingConfig() + dpo_beta: float = Field(default=0.1, description="the beta parameter for DPO loss") @validator("device") def check_device(cls, v: str): diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py index 07b12d71e..070f77a65 100644 --- a/comps/finetuning/llm_on_ray/finetune/data_process.py +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -235,6 +235,102 @@ def tokenize(self, examples): return examples +class DPODataProcessor: + def __init__(self, config, tokenizer): + self.tokenizer = tokenizer + self.max_length = config["Dataset"].get("max_length", 1024) + self.max_prompt_length = config["Dataset"].get("max_prompt_length", 512) + self.pad_to_max = config["Dataset"].get("pad_to_max", False) + + def tokenize(self, examples): + prompts = {(system + question).strip() for system, question in zip(examples["system"], examples["question"])} + chosens = {c.strip() for c in examples["chosen"]} + rejects = {r.strip() for r in examples["rejected"]} + + examples = { + "prompt": [], + "chosen": [], + "rejected": [], + "chosen_response_only": [], + "rejected_response_only": [], + "chosen_input_ids": [], + "chosen_attention_mask": [], + "chosen_labels": [], + "rejected_input_ids": [], + "rejected_attention_mask": [], + "rejected_labels": [], + "prompt_input_ids": [], + "prompt_attention_mask": [], + } + + for prompt, chosen, reject in zip(prompts, chosens, rejects): + + prompt_tokens = self.tokenizer.tokenize(prompt) + + if len(prompt_tokens) > self.max_prompt_length: + prompt_tokens = prompt_tokens[: self.max_prompt_length] + + prompt_ids = self.tokenizer.convert_tokens_to_ids(prompt_tokens) + prompt_mask = [1] * len(prompt_ids) + + max_resp = self.max_length - len(prompt_ids) + chosen_tokens = self.tokenizer.tokenize(chosen) + chosen_tokens = chosen_tokens[: max_resp - 1] + chosen_tokens.append(self.tokenizer.eos_token) + chosen_ids = self.tokenizer.convert_tokens_to_ids(chosen_tokens) + chosen_mask = [1] * len(chosen_ids) + + reject_tokens = self.tokenizer.tokenize(reject) + reject_tokens = reject_tokens[: max_resp - 1] + reject_tokens.append(self.tokenizer.eos_token) + reject_ids = self.tokenizer.convert_tokens_to_ids(reject_tokens) + reject_mask = [1] * len(reject_ids) + + chosen_input_ids = prompt_ids + chosen_ids + chosen_attention_mask = prompt_mask + chosen_mask + chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids + + reject_input_ids = prompt_ids + reject_ids + reject_attention_mask = prompt_mask + reject_mask + reject_labels = [IGNORE_INDEX] * len(prompt_ids) + reject_ids + + # padding + input_len = len(chosen_input_ids) + if self.pad_to_max: + pad_len = self.max_length - input_len + chosen_input_ids = chosen_input_ids + [0] * pad_len + chosen_labels = chosen_labels + [-100] * pad_len + chosen_attention_mask = chosen_attention_mask + [0] * pad_len + assert len(chosen_input_ids) == self.max_length + + input_len = len(reject_input_ids) + if self.pad_to_max: + pad_len = self.max_length - input_len + reject_input_ids = reject_input_ids + [0] * pad_len + reject_labels = reject_labels + [-100] * pad_len + reject_attention_mask = reject_attention_mask + [0] * pad_len + assert len(reject_input_ids) == self.max_length + + examples["prompt"].append(prompt) + examples["chosen"].append(prompt + chosen) + examples["rejected"].append(prompt + reject) + examples["chosen_response_only"].append(chosen) + examples["rejected_response_only"].append(reject) + + examples["chosen_input_ids"].append(chosen_input_ids) + examples["chosen_attention_mask"].append(chosen_attention_mask) + examples["chosen_labels"].append(chosen_labels) + + examples["rejected_input_ids"].append(reject_input_ids) + examples["rejected_attention_mask"].append(reject_attention_mask) + examples["rejected_labels"].append(reject_labels) + + examples["prompt_input_ids"].append(prompt_ids) + examples["prompt_attention_mask"].append(prompt_mask) + + return examples + + class TrainDatasetForCE(Dataset): def __init__(self, dataset, args, tokenizer): self.dataset = dataset @@ -350,3 +446,28 @@ def __call__(self, features): return_tensors="pt", ) return {"query": q_collated, "passage": d_collated} + + +@dataclass +class DPOCollator(DataCollatorWithPadding): + def __call__(self, features) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + input_ids = [torch.tensor(ins["chosen_input_ids"]) for ins in features] + [ + torch.tensor(ins["rejected_input_ids"]) for ins in features + ] + labels = [torch.tensor(ins["chosen_labels"]) for ins in features] + [ + torch.tensor(ins["rejected_labels"]) for ins in features + ] + attention_mask = [torch.tensor(ins["chosen_attention_mask"]) for ins in features] + [ + torch.tensor(ins["rejected_attention_mask"]) for ins in features + ] + + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0) + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + ) diff --git a/comps/finetuning/llm_on_ray/finetune/dpo_trainer.py b/comps/finetuning/llm_on_ray/finetune/dpo_trainer.py new file mode 100644 index 000000000..3a210548b --- /dev/null +++ b/comps/finetuning/llm_on_ray/finetune/dpo_trainer.py @@ -0,0 +1,359 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from collections import defaultdict +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset +from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training +from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments + +from comps import CustomLogger + +logger = CustomLogger("llm_on_ray/dpo_trainer") + + +def is_peft_available(): + return importlib.util.find_spec("peft") is not None + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): # pragma: no cover + module.p = 0 + + +class DPOTrainer(Trainer): + r""" + Initialize DPOTrainer, refer: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. + Used for implicit reward computation and loss. If no + reference model is provided, the trainer will + create a reference model with the same architecture as the model to be optimized. + beta (`float`, defaults to 0.1): + The beta factor in DPO loss. Higher beta means less divergence from the initial policy. + args (`transformers.TrainingArguments`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, + the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, + given a dataset of paired sequences. + label_pad_token_id (`int`, defaults to `-100`): + The label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, defaults to `0`): + The padding value. This argument is required if you want to use the default data collator. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + The callbacks to use for training. + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. + This argument is required if you want to use the default data collator. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, + the model will be wrapped in a PEFT model. + disable_dropout (`bool`, defaults to `True`): + Whether or not to disable dropouts in `model` and `ref_model`. + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + beta: float = 0.1, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = 0, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + max_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + disable_dropout: bool = True, + ): + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + self.ref_model = ref_model + + if disable_dropout: # pragma: no cover + disable_dropout_in_model(model) + disable_dropout_in_model(self.ref_model) + + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + + self.beta = beta + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + args.remove_unused_columns = False + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + ) + + if self.is_deepspeed_enabled: # pragma: no cover + # Read more about the issue in https://github.com/huggingface/trl/pull/687 + self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)[0] + self.ref_model.eval() + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities.""" + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + logits = pi_logratios - ref_logratios + + losses = -F.logsigmoid(self.beta * logits) + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. + Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. + Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log + probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: # pragma: no cover + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == self.label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + return (per_token_logps * loss_mask).sum(-1) + + def dpo_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + + len_chosen = batch["input_ids"].shape[0] // 2 + + model_kwargs = {} + + all_logits = model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + **model_kwargs, + ).logits.to(torch.float32) + + all_logps = self._get_batch_logps( + all_logits, + batch["labels"], + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def get_batch_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.dpo_forward(model, batch) + + with torch.no_grad(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.dpo_forward(self.ref_model, batch) + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" # pragma: no cover + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean() + + return losses.mean(), metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") + + # force log the metrics + if self.accelerator.is_main_process: # pragma: no cover + self.store_metrics(metrics, train_eval="train") + + if return_outputs: # pragma: no cover + return (loss, metrics) + return loss + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def log(self, logs: Dict[str, float]) -> None: + """Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" # pragma: no cover + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + # pylint: disable=E1101 + return super().log(logs) + + +try: # pragma: no cover + # pylint: disable=E0611 + from optimum.habana import GaudiConfig, GaudiTrainer # pylint: disable=E0401 + + class GaudiDPOTrainer(DPOTrainer, GaudiTrainer): + r"""Initialize habana.""" + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + beta: float = 0.1, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = 0, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + max_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + disable_dropout: bool = True, + ): + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + self.ref_model = ref_model + + if disable_dropout: # pragma: no cover + disable_dropout_in_model(model) + disable_dropout_in_model(self.ref_model) + + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + + self.beta = beta + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + args.remove_unused_columns = False + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + GaudiTrainer.__init__( + self, + model=model, + gaudi_config=gaudi_config, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + ) + if self.is_deepspeed_enabled: # pragma: no cover + # Read more about the issue in https://github.com/huggingface/trl/pull/687 + self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)[0] + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + self.ref_model.eval() + + if args.use_hpu_graphs_for_training: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # pylint: disable=E0611, E0401 + + ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_model = wrap_in_hpu_graph(ref_model) + +except: + logger.warning("HPU device not ready.") diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index 88562dec4..d105269a4 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -26,6 +26,8 @@ from comps.finetuning.finetune_config import FinetuneConfig from comps.finetuning.llm_on_ray import common from comps.finetuning.llm_on_ray.finetune.data_process import ( + DPOCollator, + DPODataProcessor, EmbedCollator, GroupCollator, InstructionDataProcessor, @@ -286,6 +288,18 @@ def group_texts(examples): desc=f"Grouping texts in chunks of {block_size}", ) + return tokenized_dataset + elif task == "dpo": + tokenizer.pad_token = tokenizer.eos_token if not tokenizer.pad_token else tokenizer.pad_token + processor = DPODataProcessor(config, tokenizer) + column_names = dataset["train"].column_names + tokenized_dataset = dataset.map( + processor.tokenize, + remove_columns=column_names, + batched=True, + load_from_cache_file=False, + desc="Tokenize dataset", + ) return tokenized_dataset elif task == "rerank": dataset["train"] = TrainDatasetForCE(dataset["train"], config["Dataset"], tokenizer) @@ -303,6 +317,8 @@ def prepare_data_collator(config: Dict, tokenizer): return transformers.DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 ) + elif task == "dpo": + return DPOCollator(tokenizer) elif task == "rerank": return GroupCollator(tokenizer) elif task == "embedding": @@ -321,10 +337,15 @@ def load_model(config: Dict): model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no")) model_config = config["General"].get("config", {}) task = config["General"].get("task", "instruction_tuning") - if task == "instruction_tuning" or task == "pretraining": + ref_model = None + if task in ["instruction_tuning", "pretraining", "dpo"]: model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config) + if task == "dpo": + ref_model = transformers.AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=model_dtype, **model_config + ) lora_config = config["General"].get("lora_config", None) - if lora_config and task != "pretraining": + if lora_config and task == "instruction_tuning": peft_config = LoraConfig(**lora_config) model = get_peft_model(model, peft_config) elif task == "rerank": @@ -362,47 +383,93 @@ def load_model(config: Dict): model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"])) - return model + return model, ref_model -def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator): +def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, data_collator): device = config["Training"]["device"] + task = config["General"].get("task", "instruction_tuning") if device in ["cpu", "gpu", "cuda"]: training_args = convert_to_training_args(TrainingArguments, config) - trainer = Trainer( - model=model, - args=training_args, - train_dataset=tokenized_dataset["train"], - eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None, - tokenizer=tokenizer, - data_collator=data_collator, - ) + if task == "dpo": + lora_config = config["General"].get("lora_config", None) + peft_config = LoraConfig(**lora_config) + from comps.finetuning.llm_on_ray.finetune.dpo_trainer import DPOTrainer + + trainer = DPOTrainer( + model, + ref_model, + args=training_args, + data_collator=data_collator, + beta=config["Training"].get("dpo_beta", 0.1), + train_dataset=tokenized_dataset["train"], + eval_dataset=( + tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None + ), + tokenizer=tokenizer, + peft_config=peft_config, + max_length=config["Dataset"].get("max_length", 1024), + ) + else: + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_dataset["train"], + eval_dataset=( + tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None + ), + tokenizer=tokenizer, + data_collator=data_collator, + ) return training_args, trainer elif device in ["hpu"]: from optimum.habana import GaudiConfig from optimum.habana.transformers import GaudiTrainer, GaudiTrainingArguments - # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config - gaudi_config_name = config["General"].get("gaudi_config_name", None) - if gaudi_config_name is not None: - gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name) - else: - gaudi_config = GaudiConfig() - gaudi_config.use_fused_adam = True - gaudi_config.use_fused_clip_norm = True - training_args = convert_to_training_args(GaudiTrainingArguments, config) - trainer = GaudiTrainer( - model=model, - args=training_args, - gaudi_config=gaudi_config, - train_dataset=tokenized_dataset["train"], - eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None, - tokenizer=tokenizer, - data_collator=data_collator, - ) + + if task == "dpo": + lora_config = config["General"].get("lora_config", None) + peft_config = LoraConfig(**lora_config) + from comps.finetuning.llm_on_ray.finetune.dpo_trainer import GaudiDPOTrainer + + trainer = GaudiDPOTrainer( + model, + ref_model, + args=training_args, + data_collator=data_collator, + beta=config["Training"].get("dpo_beta", 0.1), + train_dataset=tokenized_dataset["train"], + eval_dataset=( + tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None + ), + tokenizer=tokenizer, + peft_config=peft_config, + max_length=config["Dataset"].get("max_length", 1024), + ) + else: + # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config + gaudi_config_name = config["General"].get("gaudi_config_name", None) + if gaudi_config_name is not None: + gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name) + else: + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + trainer = GaudiTrainer( + model=model, + args=training_args, + gaudi_config=gaudi_config, + train_dataset=tokenized_dataset["train"], + eval_dataset=( + tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None else None + ), + tokenizer=tokenizer, + data_collator=data_collator, + ) return training_args, trainer - return None + return None, None def train_func(config: Dict[str, Any]): @@ -428,9 +495,9 @@ def train_func(config: Dict[str, Any]): data_collator = prepare_data_collator(config, tokenizer) - model = load_model(config) + model, ref_model = load_model(config) - training_args, trainer = get_trainer(config, model, tokenizer, tokenized_dataset, data_collator) + training_args, trainer = get_trainer(config, model, ref_model, tokenizer, tokenized_dataset, data_collator) logger.info("train start") trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) diff --git a/tests/finetuning/test_finetuning.sh b/tests/finetuning/test_finetuning.sh index 41bb02224..11a544dfd 100644 --- a/tests/finetuning/test_finetuning.sh +++ b/tests/finetuning/test_finetuning.sh @@ -73,6 +73,7 @@ function validate_finetune() { HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -H 'Content-Type: application/json' -d "$INPUT_DATA" "$URL") HTTP_STATUS=$(echo $HTTP_RESPONSE | tr -d '\n' | sed -e 's/.*HTTPSTATUS://') RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g') + FINTUNING_ID=$(echo "$RESPONSE_BODY" | jq -r '.id') # Parse the JSON response purpose=$(echo "$RESPONSE_BODY" | jq -r '.purpose') @@ -96,6 +97,26 @@ function validate_finetune() { fi sleep 10s + + # check finetuning job status + URL="$URL/retrieve" + for((i=1;i<=10;i++)); + do + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": "'$FINTUNING_ID'"}' "$URL") + echo $HTTP_RESPONSE + RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g') + STATUS=$(echo "$RESPONSE_BODY" | jq -r '.status') + if [[ "$STATUS" == "succeeded" ]]; then + echo "training: succeeded." + break + elif [[ "$STATUS" == "failed" ]]; then + echo "training: failed." + exit 1 + else + echo "training: '$STATUS'" + fi + sleep 1m + done } function validate_microservice() { @@ -148,7 +169,7 @@ EOF "rerank - finetuning" \ "test-comps-finetuning-server" \ '{"id":"ft-job' \ - '{"training_file": "test_data.json","model": "BAAI/bge-reranker-base","General":{"task":"rerank","lora_config":null}}' + '{"training_file": "test_data_rerank.json","model": "BAAI/bge-reranker-base","General":{"task":"rerank","lora_config":null}}' ########################## @@ -176,8 +197,31 @@ EOF "pretrain - finetuning" \ "test-comps-finetuning-server" \ '{"id":"ft-job' \ - '{"training_file": "test_data.json","model": "facebook/opt-125m","General":{"task":"pretraining","lora_config":null}}' + '{"training_file": "test_data_pretrain.json","model": "facebook/opt-125m","General":{"task":"pretraining","lora_config":null}}' + + + ########################## + # dpo test # + ########################## + # test /v1/dataprep upload file + cat < test_data_dpo.jsonl +{"system": "", "question": "You will be given a definition of a task first, then some input of the task.\nThis task is about using the specified sentence and converting the sentence to Resource Description Framework (RDF) triplets of the form (subject, predicate object). The RDF triplets generated must be such that the triplets accurately capture the structure and semantics of the input sentence. The input is a sentence and the output is a list of triplets of the form [subject, predicate, object] that capture the relationships present in the sentence. When a sentence has more than 1 RDF triplet possible, the output must contain all of them.\n\nAFC Ajax (amateurs)'s ground is Sportpark De Toekomst where Ajax Youth Academy also play.\nOutput:", "chosen": "[\n [\"AFC Ajax (amateurs)\", \"has ground\", \"Sportpark De Toekomst\"],\n [\"Ajax Youth Academy\", \"plays at\", \"Sportpark De Toekomst\"]\n]", "rejected": " Sure, I'd be happy to help! Here are the RDF triplets for the input sentence:\n\n[AFC Ajax (amateurs), hasGround, Sportpark De Toekomst]\n[Ajax Youth Academy, playsAt, Sportpark De Toekomst]\n\nExplanation:\n\n* AFC Ajax (amateurs) is the subject of the first triplet, and hasGround is the predicate that describes the relationship between AFC Ajax (amateurs) and Sportpark De Toekomst.\n* Ajax Youth Academy is the subject of the second triplet, and playsAt is the predicate that describes the relationship between Ajax Youth Academy and Sportpark De Toekomst.\n\nNote that there may be other possible RDF triplets that could be derived from the input sentence, but the above triplets capture the main relationships present in the sentence."} +{"system": "You are an AI assistant. You will be given a task. You must generate a detailed and long answer.", "question": "Generate an approximately fifteen-word sentence that describes all this data: Midsummer House eatType restaurant; Midsummer House food Chinese; Midsummer House priceRange moderate; Midsummer House customer rating 3 out of 5; Midsummer House near All Bar One", "chosen": "Midsummer House is a moderately priced Chinese restaurant with a 3/5 customer rating, located near All Bar One.", "rejected": " Sure! Here's a sentence that describes all the data you provided:\n\n\"Midsummer House is a moderately priced Chinese restaurant with a customer rating of 3 out of 5, located near All Bar One, offering a variety of delicious dishes.\""} +EOF + validate_upload \ + "http://${ip_address}:$finetuning_service_port/v1/files" \ + "dpo - upload" \ + "test-comps-finetuning-server" \ + "fine-tune" \ + "test_data_dpo.jsonl" + # test /v1/fine_tuning/jobs + validate_finetune \ + "http://${ip_address}:$finetuning_service_port/v1/fine_tuning/jobs" \ + "dpo - finetuning" \ + "test-comps-finetuning-server" \ + '{"id":"ft-job' \ + '{"training_file": "test_data_dpo.jsonl","model": "facebook/opt-125m","General":{"task":"dpo"}}' }