From cf6b9836ac9cbeb29bd1fe0a209a2c38c98f4d51 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 3 May 2024 08:41:59 -0400
Subject: [PATCH] update for sppo

---
 docs/config.qmd                                           | 2 +-
 requirements.txt                                          | 2 +-
 src/axolotl/core/trainer_builder.py                       | 8 ++++----
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py  | 2 +-
 src/axolotl/utils/models.py                               | 2 +-
 src/axolotl/utils/trainer.py                              | 2 +-
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index bb3158e505..568de5cb1a 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -138,7 +138,7 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl
 
-# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
 rl:
 
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
diff --git a/requirements.txt b/requirements.txt
index 4ec2aec89c..39c1623c3a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b
+trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 8049fd84d3..0315ca4049 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1520,7 +1520,7 @@ def build_training_arguments(self, total_num_steps):
         training_args_cls = TrainingArguments
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
-        elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
+        elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
             training_args_cls = DPOConfig
 
         training_args = training_args_cls(
@@ -1548,8 +1548,8 @@ def build(self, total_num_steps):
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
-        elif self.cfg.rl == "sppo":
-            dpo_trainer_kwargs["loss_type"] = "sppo"
+        elif self.cfg.rl == "sppo_hard":
+            dpo_trainer_kwargs["loss_type"] = "sppo_hard"
         if self.eval_dataset:
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
@@ -1558,7 +1558,7 @@ def build(self, total_num_steps):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
             trainer_cls = AxolotlDPOTrainer
             dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 1d81b141bf..6c5283fb04 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -133,7 +133,7 @@ class RLType(str, Enum):
    ipo = "ipo"  # pylint: disable=invalid-name
    kto_pair = "kto_pair"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
-    sppo = "sppo"  # pylint: disable=invalid-name
+    sppo = "sppo_hard"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f0ae55a739..fc8a67acf2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -791,7 +791,7 @@ def load_model(
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
         if (
             cfg.adapter
-            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
             and not cfg.merge_lora
         ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index fe1f6e0bd1..1a0e550103 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
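
Usage note (not part of the patch): after this change the mode is selected through the usual rl key in the axolotl YAML config. A minimal sketch, using only keys that appear in this diff and assuming a preference/DPO-format dataset is already configured:

    # use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
    rl: sppo_hard

    # optional; trainer_builder.py falls back to 0.1 when dpo_beta is unset
    dpo_beta: 0.1

With rl: sppo_hard, HFRLTrainerBuilder constructs an AxolotlDPOTrainer with loss_type="sppo_hard", which the updated trl pin is expected to accept as a DPO loss variant.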