From 03878a58b7510074feb01dc0b0611d335e9140c5 Mon Sep 17 00:00:00 2001 From: Kayky Ramos Date: Thu, 11 Jul 2024 15:54:02 +0000 Subject: [PATCH 1/2] Implements SPPO Alignment Algoritm --- examples/llama-3/sppo-qlora-8b.yml | 74 ++ src/axolotl/core/trainer_builder.py | 76 +- src/axolotl/custom/trainers/SPPOTrainer.py | 1005 +++++++++++++++++ src/axolotl/prompt_strategies/dpo/chatml.py | 20 + .../config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/trainer.py | 2 +- 6 files changed, 1176 insertions(+), 2 deletions(-) create mode 100644 examples/llama-3/sppo-qlora-8b.yml create mode 100644 src/axolotl/custom/trainers/SPPOTrainer.py diff --git a/examples/llama-3/sppo-qlora-8b.yml b/examples/llama-3/sppo-qlora-8b.yml new file mode 100644 index 0000000000..e4454e619a --- /dev/null +++ b/examples/llama-3/sppo-qlora-8b.yml @@ -0,0 +1,74 @@ +base_model: meta-llama/Meta-Llama-3-8B-Instruct +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: false +load_in_4bit: true +strict: false + +rl: sppo +rl_beta: 0.1 +datasets: + - path: orion-research/Aura-SPPO-Iter1_score + type: chatml.sppo_argilla_chat +dataset_prepared_path: +val_set_size: 0 + +output_dir: ./outputs/out/Meta-Llama-3-8B-Instruct-SPPO-Iter1 +dataset_prepared_path: last_run_prepared + +adapter: qlora +lora_model_dir: + +sequence_len: 1024 +sample_packing: false +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_target_modules: +peft_use_dora: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 8 +micro_batch_size: 1 +num_epochs: 1 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 2.0e-4 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +save_safetensors: true +warmup_steps: 50 +evals_per_epoch: 1 +eval_max_new_tokens: 128 +eval_table_size: +save_steps: 100 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: "<|end_of_text|>" diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ec175454e9..7f862ea5df 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,6 +30,7 @@ ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled +from axolotl.custom.trainers.SPPOTrainer import SPPOTrainer from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer from trl.trainer.utils import pad_to_length @@ -244,6 +245,11 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): DPO config for DPO training """ +@dataclass +class AxolotlSPPOConfig(AxolotlTrainingMixins, DPOConfig): + """ + DPO config for DPO training + """ @dataclass class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): @@ -897,6 +903,65 @@ def tokenize_row( res[key] = res[key][1:] return res +class AxolotlSPPOTrainer(SPPOTrainer): + """ + Extend the base SPPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "sppo"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.optimizer = None + + def create_optimizer(self): + if self.args.loraplus_lr_ratio is None: + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else 
self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + if loraplus_lr_ratio: + print("Using lora+") + loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + optimizer_kwargs, + loraplus_lr_ratio, + loraplus_lr_embedding, + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + + @wraps(DPOTrainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + + return super().push_to_hub(*args, **kwargs) + + def tokenize_row( + self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None + ) -> Dict: + res = super().tokenize_row(feature, model=model) + if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None: + for key in res.keys(): + res[key] = res[key][1:] + return res class AxolotlORPOTrainer(ORPOTrainer): """ @@ -1521,7 +1586,7 @@ def build_collator( class HFRLTrainerBuilder(TrainerBuilderBase): """ - Trainer factory class for DPO Trainer + Trainer factory class for DPO/SPPO Trainer """ def get_callbacks(self): @@ -1690,6 +1755,15 @@ def build(self, total_num_steps): dpo_trainer_kwargs["generate_during_eval"] = True if self.cfg.rl == "dpo": dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + if self.cfg.rl in ["sppo"]: + trainer_cls = AxolotlSPPOTrainer + dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1 + trainer_cls_args = [self.model, self.model_ref] + dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len + dpo_trainer_kwargs["max_target_length"] = None + dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len + dpo_trainer_kwargs["generate_during_eval"] = True + dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] diff --git a/src/axolotl/custom/trainers/SPPOTrainer.py b/src/axolotl/custom/trainers/SPPOTrainer.py new file mode 100644 index 0000000000..f2b4f967fb --- /dev/null +++ b/src/axolotl/custom/trainers/SPPOTrainer.py @@ -0,0 +1,1005 @@ +import inspect +import random +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from functools import wraps +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate.utils import is_deepspeed_available, tqdm +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from trl.import_utils import is_peft_available, is_wandb_available +from trl.models 
import PreTrainedModelWrapper, create_reference_model +from trl.trainer.utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + pad_to_length, + trl_sanitze_kwargs_for_tagging, + peft_module_casting_to_bf16, +) + +import bitsandbytes as bnb + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + +if is_deepspeed_available(): + import deepspeed + +class SPPOTrainer(Trainer): + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid", + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = 0, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + precompute_ref_log_probs: bool = False, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + model_adapter_name: str = None, + ref_adapter_name: str = None, + dataset_num_proc: str = None + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained( + model, + quantization_config=bnb.QuantizationConfig(bits=4), + **model_init_kwargs + ) + + if isinstance(ref_model, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. 
This will automatically create an " + "`AutoModelForCausalLM`" + ) + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + self._peft_has_been_casted_to_bf16 = True + + + elif getattr(args, "gradient_checkpointing", False): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." 
+ ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or precompute_ref_log_probs: + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if tokenizer is None: + raise ValueError("tokenizer must be specified to tokenize a DPO dataset.") + if max_length is None: + warnings.warn( + "`max_length` is not set in the DPOTrainer's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the DPOTrainer's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = truncation_mode + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." 
+ ) + + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + train_dataset = train_dataset.map(self.tokenize_row) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def get_train_dataloader(self) -> DataLoader: + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + + reference_chosen_logps = [] + reference_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) + reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics( + (reference_chosen_logp, reference_rejected_logp) + ) + reference_chosen_logps.append(reference_chosen_logp.cpu()) + reference_rejected_logps.append(reference_rejected_logp.cpu()) + + all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() + all_reference_rejected_logps = 
torch.cat(reference_rejected_logps).float().numpy() + + self.train_dataset = self.train_dataset.add_column( + name="reference_chosen_logps", column=all_reference_chosen_logps + ) + self.train_dataset = self.train_dataset.add_column( + name="reference_rejected_logps", column=all_reference_rejected_logps + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_eval_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + reference_chosen_logps = [] + reference_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) + reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics( + (reference_chosen_logp, reference_rejected_logp) + ) + reference_chosen_logps.append(reference_chosen_logp.cpu()) + reference_rejected_logps.append(reference_rejected_logp.cpu()) + + all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() + all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy() + + eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps) + eval_dataset = eval_dataset.add_column( + name="reference_rejected_logps", column=all_reference_rejected_logps + ) + + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + def build_tokenized_answer(self, prompt, answer): + full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + response_token_ids_start_idx = len(prompt_input_ids) + + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( 
+ prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict: + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + if not self.is_encoder_decoder: + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + num_diff_tokens = sum( + [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])] + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"] + chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"] + rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"] + + prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] + chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] + rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] + + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + chosen_sequence_tokens = { + k: 
chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.tokenizer( + chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + rejected_tokens = self.tokenizer( + rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True + ) + prompt_tokens = self.tokenizer( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=batch["rejected_labels"] + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=batch["chosen_labels"] + ) + + return batch + + @contextmanager + def null_ref_context(self): + with self.accelerator.unwrap_model( + self.model + ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: + compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with torch.no_grad(), compte_ref_context_manager(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, padded_batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _ + ) = self.concatenated_forward(self.ref_model, padded_batch) + + return reference_chosen_logps, reference_rejected_logps + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> Dict[str, torch.LongTensor]: + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + 
pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + chosen_probs: Union[torch.FloatTensor, None] = None, + chosen_probs_win: Union[torch.FloatTensor, None] = None, + chosen_probs_lose: Union[torch.FloatTensor, None] = None, + reference_free: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + if reference_free: + ref_logratios = 0 + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps + + pi_logratios = pi_logratios.to(self.accelerator.device) + ref_logratios = ref_logratios.to(self.accelerator.device) + logits = pi_logratios - ref_logratios + + logits_w = policy_chosen_logps - reference_chosen_logps + logits_l = policy_rejected_logps - reference_rejected_logps + + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": + losses = (logits - 1 / (2 * self.beta)) ** 2 + elif self.loss_type == "sppo": + loss_w = (logits_w - (1 / self.beta)*(chosen_probs_win - 0.5)) ** 2 + loss_l = (logits_l - (1 / self.beta)*(chosen_probs_lose - 0.5)) ** 2 + losses = (loss_w + loss_l)/2 + elif self.loss_type == "sppo_single": + loss_w = (logits_w - (1 / self.beta)*(chosen_probs - 0.5)) ** 2 + loss_l = (logits_l + (1 / self.beta)*(chosen_probs - 0.5)) ** 2 + losses = (loss_w + loss_l)/2 + elif self.loss_type == "kto_pair": + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + losses = torch.cat( + ( + 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)), + ), + 0, + ) + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. 
Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']" + ) + + chosen_rewards = ( + self.beta + * ( + policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device) + ).detach() + ) + rejected_rewards = ( + self.beta + * ( + policy_rejected_logps.to(self.accelerator.device) + - reference_rejected_logps.to(self.accelerator.device) + ).detach() + ) + + return losses, chosen_rewards, rejected_rewards + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + labels[labels == label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"].to(torch.int64), + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + + concatenated_batch["concatenated_input_ids"] = concatenated_batch["concatenated_input_ids"].to(torch.int32) + concatenated_batch["concatenated_attention_mask"] = concatenated_batch["concatenated_attention_mask"].to(torch.int32) + concatenated_batch["concatenated_labels"] = concatenated_batch["concatenated_labels"].to(torch.int64) + + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ).logits + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + len_chosen = batch["chosen_labels"].shape[0] + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(model, batch) + + chosen_probs = torch.tensor(batch["chosen_probs"], dtype=float, device=policy_chosen_logps.device) + chosen_probs_win = torch.tensor(batch["chosen_probs_win"], dtype=float, device=policy_chosen_logps.device) + chosen_probs_lose = torch.tensor(batch["chosen_probs_lose"], dtype=float, device=policy_chosen_logps.device) + + if 
"reference_chosen_logps" in batch and "reference_rejected_logps" in batch: + reference_chosen_logps = batch["reference_chosen_logps"] + reference_rejected_logps = batch["reference_rejected_logps"] + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _ + ) = self.concatenated_forward(self.ref_model, batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_probs, + chosen_probs_win, + chosen_probs_lose, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + + return losses.mean(), metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_dpo_data_collator: + warnings.warn( + "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. 
Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + + compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with compute_loss_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast + + with generate_context_manager(): + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) + policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id) + reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ): + if not self.use_dpo_data_collator: + warnings.warn( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. 
Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with torch.no_grad(), prediction_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys) + logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + if self.generate_during_eval: + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) + + self.log( + { + "game_log": wandb.Table( + columns=["Prompt", "Policy", "Ref Model"], + rows=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + random_batch["prompt"], policy_output_decoded, ref_output_decoded + ) + ], + ) + } + ) + self.state.log_history.pop() + + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: Dict[str, float]) -> None: + train_eval = "train" if "loss" in logs else "eval" + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) \ No newline at end of file diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 585696e29a..1740c0474e 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -42,6 +42,26 @@ def transform_fn(sample): return transform_fn +def sppo_argilla_chat( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + for argilla/dpo-mix-7k conversations + """ + + def transform_fn(sample): + sample[ + "prompt" + ] = 
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + sample["chosen_probs"] = sample['chosen_probs'] + sample["chosen_probs_lose"] = sample['chosen_probs_lose'] + sample["chosen_probs_win"] = sample['chosen_probs_win'] + return sample + + return transform_fn def icr( cfg, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3cac4f8391..96fa76a683 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -166,6 +166,7 @@ class RLType(str, Enum): dpo = "dpo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name + sppo = "sppo" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a16baaae0f..029067beae 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -427,7 +427,7 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "orpo", "kto"]: + if cfg.rl in ["dpo", "ipo", "orpo", "kto", "sppo"]: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] From 32ff7ea65580947e7618188da6774e468247bf15 Mon Sep 17 00:00:00 2001 From: Lord Kayky Ramos Date: Fri, 12 Jul 2024 10:23:20 -0300 Subject: [PATCH 2/2] Update examples/llama-3/sppo-qlora-8b.yml Co-authored-by: Wing Lian --- examples/llama-3/sppo-qlora-8b.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama-3/sppo-qlora-8b.yml b/examples/llama-3/sppo-qlora-8b.yml index e4454e619a..8f10f44fb0 100644 --- a/examples/llama-3/sppo-qlora-8b.yml +++ b/examples/llama-3/sppo-qlora-8b.yml @@ -41,7 +41,7 @@ wandb_log_model: gradient_accumulation_steps: 8 micro_batch_size: 1 num_epochs: 1 -optimizer: paged_adamw_32bit +optimizer: paged_adamw_8bit lr_scheduler: cosine learning_rate: 2.0e-4