Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements SPPO Alignment Algoritm #1735

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

kaykyr
Copy link

@kaykyr kaykyr commented Jul 11, 2024

Implements SPPO Alignment Algorithm

Description

This pull request implements the Self-Play Preference Optimization (SPPO) algorithm for language model alignment. The SPPO algorithm, as described in the paper "Self-Play Preference Optimization for Language Model Alignment" (available at https://arxiv.org/abs/2405.00675), uses a self-play mechanism to optimize language models based on preference probabilities. This implementation leverages the code from the original repository at https://github.com/uclaml/SPPO and integrates it into the Axolotl framework.

Motivation and Context

This change is required to improve the alignment of language models with human preferences, addressing issues of reliability, safety, and ethical considerations in language model outputs. The SPPO algorithm provides a more flexible and accurate method for preference optimization compared to traditional reinforcement learning approaches.

How has this been tested?

The implementation has been tested using a variety of prompts from the UltraFeedback dataset, evaluating the model's performance on AlpacaEval 2.0 and MT-Bench. The tests involved assessing the log-likelihood of chosen responses and comparing the model's win rates against state-of-the-art models, ensuring that the changes do not adversely affect other areas of the codebase.

Screenshots (if appropriate)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Social Handles (Optional)

GitHub: @kaykyr
HuggingFace: https://huggingface.co/kaykyramos
Discord: kaykyramos

Copy link
Collaborator

@winglian winglian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @kaykyr, Thanks for submitting this technique. I'd love to see this integrated into axolotl, but my main concern is the amount of duplicated code we're going to have to maintain. I'm happy to help refactor the pieces in the trainer_builder, but I think it would be ideal if we could extract the necessary SPPO changes from DPOTrainer so we have a smaller footprint to maintain.

re: tests, would be good to have some tests to spot check the functionality. I'm happy to help with this as well, where we setup some e2e tests that run a small model for about 10-20 steps to verify that the trainer works.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably go in src/axolotl/core/trainers

examples/llama-3/sppo-qlora-8b.yml Outdated Show resolved Hide resolved
Comment on lines +917 to +945
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()

opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)

loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)

return self.optimizer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it might be worth extracting this as a AxolotlCreateOptimizerMixin and then including it in both here and the AxolotlTrainer


return self.optimizer

@wraps(DPOTrainer.push_to_hub)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is DPOTrainer correct for this?

Comment on lines +948 to +964
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

return super().push_to_hub(*args, **kwargs)

def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more duplicated code that makes me think we should be extracting this into a Mixin.

if is_deepspeed_available():
import deepspeed

class SPPOTrainer(Trainer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty big class that seems to duplicate a lot of code from likely the DPOTrainer I would assume? Would it make more sense to extend the DPOTrainer and just implement the necessary changes?

Comment on lines +45 to +64
def sppo_argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""

def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
sample["chosen_probs"] = sample['chosen_probs']
sample["chosen_probs_lose"] = sample['chosen_probs_lose']
sample["chosen_probs_win"] = sample['chosen_probs_win']
return sample

return transform_fn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because this is in the dpo path, which iirc, it is loaded based on the rl: ... setting, I'm not sure that this will load as expected.

@kaykyr
Copy link
Author

kaykyr commented Jul 12, 2024

hi @kaykyr, Thanks for submitting this technique. I'd love to see this integrated into axolotl, but my main concern is the amount of duplicated code we're going to have to maintain. I'm happy to help refactor the pieces in the trainer_builder, but I think it would be ideal if we could extract the necessary SPPO changes from DPOTrainer so we have a smaller footprint to maintain.

re: tests, would be good to have some tests to spot check the functionality. I'm happy to help with this as well, where we setup some e2e tests that run a small model for about 10-20 steps to verify that the trainer works.

Hey @winglian, I'll do my best to submit a better pull request doing a better approach to SPPO Integration.

@kaykyr
Copy link
Author

kaykyr commented Jul 12, 2024

hi @kaykyr, Thanks for submitting this technique. I'd love to see this integrated into axolotl, but my main concern is the amount of duplicated code we're going to have to maintain. I'm happy to help refactor the pieces in the trainer_builder, but I think it would be ideal if we could extract the necessary SPPO changes from DPOTrainer so we have a smaller footprint to maintain.

re: tests, would be good to have some tests to spot check the functionality. I'm happy to help with this as well, where we setup some e2e tests that run a small model for about 10-20 steps to verify that the trainer works.

I am also running 3 iterations and I'll upload the result models to hugging face for comparison... At this momento I am running the iter2 on my homelab.

{'loss': 0.4766, 'grad_norm': 5.125, 'learning_rate': 0.00015594912061278626, 'rewards/chosen': -1.9593877792358398, 'rewards/rejected': -3.4484448432922363, 'rewards/accuracies': 0.5, 'rewards/margins': 1.4890570640563965, 'logps/rejected': -156.06246948242188, 'logps/chosen': -163.50051879882812, 'logits/rejected': -0.08261816203594208, 'logits/chosen': 0.01163027435541153, 'epoch': 1.0}
{'loss': 0.1187, 'grad_norm': 0.85546875, 'learning_rate': 0.000155500908021347, 'rewards/chosen': -1.4073667526245117, 'rewards/rejected': -7.409327983856201, 'rewards/accuracies': 1.0, 'rewards/margins': 6.0019612312316895, 'logps/rejected': -302.47320556640625, 'logps/chosen': -217.1146697998047, 'logits/rejected': -0.17893444001674652, 'logits/chosen': -0.3086361289024353, 'epoch': 1.01}
{'loss': 0.2105, 'grad_norm': 3.84375, 'learning_rate': 0.00015505107827058036, 'rewards/chosen': -2.149808645248413, 'rewards/rejected': -5.619155406951904, 'rewards/accuracies': 0.75, 'rewards/margins': 3.469346761703491, 'logps/rejected': -212.09442138671875, 'logps/chosen': -163.01132202148438, 'logits/rejected': -0.03775382041931152, 'logits/chosen': -0.1259421706199646, 'epoch': 1.01}
{'loss': 0.1979, 'grad_norm': 1.2890625, 'learning_rate': 0.00015459964446741382, 'rewards/chosen': -0.6550925374031067, 'rewards/rejected': -4.351650714874268, 'rewards/accuracies': 0.625, 'rewards/margins': 3.6965579986572266, 'logps/rejected': -196.5653839111328, 'logps/chosen': -149.3184814453125, 'logits/rejected': -0.23472216725349426, 'logits/chosen': -0.27307409048080444, 'epoch': 1.02}
{'loss': 0.2405, 'grad_norm': 1.984375, 'learning_rate': 0.00015414661976551302, 'rewards/chosen': -0.46041208505630493, 'rewards/rejected': -5.224635601043701, 'rewards/accuracies': 0.75, 'rewards/margins': 4.764223098754883, 'logps/rejected': -244.72647094726562, 'logps/chosen': -190.26235961914062, 'logits/rejected': -0.25164929032325745, 'logits/chosen': -0.09058046340942383, 'epoch': 1.02}
{'loss': 0.0668, 'grad_norm': 1.984375, 'learning_rate': 0.0001536920173648984, 'rewards/chosen': -2.3507869243621826, 'rewards/rejected': -6.221563816070557, 'rewards/accuracies': 1.0, 'rewards/margins': 3.870777130126953, 'logps/rejected': -314.01568603515625, 'logps/chosen': -270.7308044433594, 'logits/rejected': 0.017346393316984177, 'logits/chosen': 0.05637218803167343, 'epoch': 1.03}
{'loss': 0.227, 'grad_norm': 1.8125, 'learning_rate': 0.0001532358505115607, 'rewards/chosen': -0.6020192503929138, 'rewards/rejected': -5.3337531089782715, 'rewards/accuracies': 0.875, 'rewards/margins': 4.731733798980713, 'logps/rejected': -214.8163299560547, 'logps/chosen': -182.3211212158203, 'logits/rejected': -0.021296918392181396, 'logits/chosen': 0.01770481839776039, 'epoch': 1.03}
{'loss': 0.1773, 'grad_norm': 2.6875, 'learning_rate': 0.00015277813249707487, 'rewards/chosen': -1.7499217987060547, 'rewards/rejected': -5.255713939666748, 'rewards/accuracies': 1.0, 'rewards/margins': 3.5057921409606934, 'logps/rejected': -310.7301330566406, 'logps/chosen': -271.0190734863281, 'logits/rejected': 0.045894794166088104, 'logits/chosen': -0.04848558083176613, 'epoch': 1.04}
{'loss': 0.4042, 'grad_norm': 3.546875, 'learning_rate': 0.000152318876658213, 'rewards/chosen': -1.4743281602859497, 'rewards/rejected': -4.679446697235107, 'rewards/accuracies': 0.875, 'rewards/margins': 3.205118179321289, 'logps/rejected': -271.888671875, 'logps/chosen': -239.9329833984375, 'logits/rejected': 0.010095290839672089, 'logits/chosen': -0.08229245990514755, 'epoch': 1.04}
{'loss': 0.2753, 'grad_norm': 2.671875, 'learning_rate': 0.0001518580963765555, 'rewards/chosen': -3.2833800315856934, 'rewards/rejected': -6.391260147094727, 'rewards/accuracies': 0.875, 'rewards/margins': 3.1078805923461914, 'logps/rejected': -283.3193054199219, 'logps/chosen': -248.41119384765625, 'logits/rejected': -0.036375753581523895, 'logits/chosen': -0.032573096454143524, 'epoch': 1.05}
{'loss': 0.0712, 'grad_norm': 1.40625, 'learning_rate': 0.00015139580507810119, 'rewards/chosen': -0.22441568970680237, 'rewards/rejected': -4.1824116706848145, 'rewards/accuracies': 1.0, 'rewards/margins': 3.957995891571045, 'logps/rejected': -218.41976928710938, 'logps/chosen': -175.13992309570312, 'logits/rejected': 0.10275314003229141, 'logits/chosen': 0.039625994861125946, 'epoch': 1.05}
{'loss': 0.066, 'grad_norm': 1.984375, 'learning_rate': 0.00015093201623287631, 'rewards/chosen': -1.5608439445495605, 'rewards/rejected': -7.547214508056641, 'rewards/accuracies': 1.0, 'rewards/margins': 5.98637056350708, 'logps/rejected': -326.3816223144531, 'logps/chosen': -270.4262390136719, 'logits/rejected': -0.08117158710956573, 'logits/chosen': -0.03967729210853577, 'epoch': 1.06}
{'loss': 0.1853, 'grad_norm': 2.515625, 'learning_rate': 0.0001504667433545419, 'rewards/chosen': -0.02940535545349121, 'rewards/rejected': -5.676110744476318, 'rewards/accuracies': 0.875, 'rewards/margins': 5.646705150604248, 'logps/rejected': -262.2330017089844, 'logps/chosen': -205.32958984375, 'logits/rejected': -0.13373667001724243, 'logits/chosen': -0.18218691647052765, 'epoch': 1.06}
{'loss': 0.1211, 'grad_norm': 1.2265625, 'learning_rate': 0.00015000000000000001, 'rewards/chosen': -1.7262533903121948, 'rewards/rejected': -5.780290603637695, 'rewards/accuracies': 1.0, 'rewards/margins': 4.054037094116211, 'logps/rejected': -259.7177734375, 'logps/chosen': -185.06295776367188, 'logits/rejected': -0.01202734187245369, 'logits/chosen': -0.19560036063194275, 'epoch': 1.07}
{'loss': 0.1652, 'grad_norm': 1.796875, 'learning_rate': 0.00014953179976899878, 'rewards/chosen': -2.2642948627471924, 'rewards/rejected': -6.6951165199279785, 'rewards/accuracies': 0.75, 'rewards/margins': 4.430821418762207, 'logps/rejected': -271.29461669921875, 'logps/chosen': -222.35409545898438, 'logits/rejected': -0.15404176712036133, 'logits/chosen': -0.19560746848583221, 'epoch': 1.07}
 36%|█████████████████████████████████████████████████████████████▍                                                                                                              | 215/602 [1:53:26<3:32:56, 33.01s/it]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants