diff --git a/src/tests/trainers/test_trainers.py b/src/tests/trainers/test_trainers.py
index 69399021..c36f4fcd 100644
--- a/src/tests/trainers/test_trainers.py
+++ b/src/tests/trainers/test_trainers.py
@@ -1845,7 +1845,7 @@ def test_seq2seq(self, create_datadreamer, mocker):
             }
             trainer_path = cast(str, trainer._output_folder_path)
             with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f:
-                assert json.load(f) == "124679f0dc66f8d8"
+                assert json.load(f) == "172ce12f7687547a"
             assert train_result is trainer
             assert (
                 type(get_orig_model(trainer.model)).__name__
@@ -1976,7 +1976,7 @@ def test_causal(self, create_datadreamer, mocker):
             }  # fmt: skip
             trainer_path = cast(str, trainer._output_folder_path)
             with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f:
-                assert json.load(f) == "ec6f56c1bad25250"
+                assert json.load(f) == "8e953e25b69f8d82"
             assert train_result is trainer
             assert type(get_orig_model(trainer.model)).__name__ == "GPT2LMHeadModel"
             assert trainer.model_path == os.path.join(trainer_path, "_model")
@@ -2060,7 +2060,7 @@ def test_peft(self, create_datadreamer, mocker):
             )
             trainer_path = cast(str, trainer._output_folder_path)
             with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f:
-                assert json.load(f) == "bfeb4e3d409fa1f6"
+                assert json.load(f) == "9682924c97dd492c"
             assert train_result is trainer
             assert (
                 type(get_orig_model(trainer.model)).__name__ == "PeftModelForCausalLM"
diff --git a/src/trainers/_dpo_helper.py b/src/trainers/_dpo_helper.py
new file mode 100644
index 00000000..d3baed69
--- /dev/null
+++ b/src/trainers/_dpo_helper.py
@@ -0,0 +1,292 @@
+# type: ignore
+# flake8: noqa
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from ..utils.import_utils import ignore_transformers_warnings
+
+with ignore_transformers_warnings():
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+
+@dataclass
+class DPODataCollatorWithPadding:  # pragma: no cover
+    r"""
+    DPO DataCollator class that pads the inputs to the maximum length of the batch.
+    Args:
+        tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used for encoding the data.
+        model (Optional[`PreTrainedModel`]):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*.
+        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
+            padding_strategy to pass to the tokenizer.
+        max_length (`Optional[int]`, `optional`, defaults to `None`):
+            The maximum length of the sequence to be processed.
+        max_prompt_length (`Optional[int]`, `optional`, defaults to `None`):
+            The maximum length of the prompt to be processed.
+        label_pad_token_id (`int`, defaults to -100):
+            The label used for masking.
+        padding_value (`int`, defaults to 0):
+            The value used for padding.
+        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
+            Whether or not your model has an encoder_decoder architecture.
+        max_target_length (`Optional[int]`, `optional`, defaults to `None`):
+            The maximum length of the target to be processed. Only useful for encoder-decoder architectures.
+        truncation_mode (`str`, defaults to "keep_end"):
+            The truncation mode to use when truncating the prompt.
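+
+    Adapted from `trl.trainer.utils.DPODataCollatorWithPadding`; this patch vendors the
+    collator here and swaps the `trl` import in `_train_hf_base.py` (below) for this copy.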
+ """ + tokenizer: PreTrainedTokenizerBase + model: Optional[PreTrainedModel] = None + padding: Union[bool, str] = True + max_length: Optional[int] = None + max_prompt_length: Optional[int] = None + label_pad_token_id: int = -100 + padding_value: int = 0 + truncation_mode: str = "keep_end" + is_encoder_decoder: Optional[bool] = False + max_target_length: Optional[int] = None + + def tokenize_batch_element( + self, + prompt: str, + chosen: str, + rejected: str, + ) -> Dict: + """Tokenize a single batch element. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + + if not self.is_encoder_decoder: + chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) + rejected_tokens = self.tokenizer(rejected, add_special_tokens=False) + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + + eos_token_id = self.tokenizer.eos_token_id + # Get indices in list prompt_tokens["input_ids"] that equals the EOS token (often 0) + eos_indices_prompt = [ + i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id + ] + # attention mask these indices to eos_token_id + new_attention_mask = [ + 0 if i in eos_indices_prompt else p + for i, p in enumerate(prompt_tokens["attention_mask"]) + ] + prompt_tokens["attention_mask"] = new_attention_mask + + # do the same for chosen and rejected + eos_indices_chosen = [ + i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id + ] + new_attention_mask_c = [ + 0 if i in eos_indices_chosen else p + for i, p in enumerate(chosen_tokens["attention_mask"]) + ] + chosen_tokens["attention_mask"] = new_attention_mask_c + + eos_indices_rejected = [ + i + for i, x in enumerate(rejected_tokens["input_ids"]) + if x == eos_token_id + ] + new_attention_mask_r = [ + 0 if i in eos_indices_rejected else p + for i, p in enumerate(rejected_tokens["attention_mask"]) + ] + rejected_tokens["attention_mask"] = new_attention_mask_r + + # add EOS token to end of prompt + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max( + len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]) + ) + + # if combined sequence is too long, truncate the prompt + if ( + len(prompt_tokens["input_ids"]) + longer_response_length + > self.max_length + ): + if self.truncation_mode == "keep_start": + prompt_tokens = { + k: v[: self.max_prompt_length] for k, v in prompt_tokens.items() + } + elif self.truncation_mode == "keep_end": + prompt_tokens = { + k: v[-self.max_prompt_length :] + for k, v in prompt_tokens.items() + } + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + if ( + len(prompt_tokens["input_ids"]) + longer_response_length + > self.max_length + ): + chosen_tokens = { + k: v[: self.max_length - self.max_prompt_length] + for k, v in chosen_tokens.items() + } + rejected_tokens = { + k: v[: self.max_length - self.max_prompt_length] + for k, 
+
+            # Create labels
+            chosen_sequence_tokens = {
+                k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens
+            }
+            rejected_sequence_tokens = {
+                k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens
+            }
+            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
+            chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [
+                self.label_pad_token_id
+            ] * len(prompt_tokens["input_ids"])
+            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][
+                :
+            ]
+            rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [
+                self.label_pad_token_id
+            ] * len(prompt_tokens["input_ids"])
+
+            for k, toks in {
+                "chosen": chosen_sequence_tokens,
+                "rejected": rejected_sequence_tokens,
+                "prompt": prompt_tokens,
+            }.items():
+                for type_key, tokens in toks.items():
+                    if type_key == "token_type_ids":
+                        continue
+                    batch[f"{k}_{type_key}"] = tokens
+
+        else:
+            chosen_tokens = self.tokenizer(
+                chosen,
+                truncation=True,
+                max_length=self.max_target_length,
+                add_special_tokens=True,
+            )
+            rejected_tokens = self.tokenizer(
+                rejected,
+                truncation=True,
+                max_length=self.max_target_length,
+                add_special_tokens=True,
+            )
+            prompt_tokens = self.tokenizer(
+                prompt,
+                truncation=True,
+                max_length=self.max_prompt_length,
+                add_special_tokens=True,
+            )
+
+            batch["chosen_labels"] = chosen_tokens["input_ids"]
+            batch["rejected_labels"] = rejected_tokens["input_ids"]
+            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
+            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
+
+            if self.model is not None and hasattr(
+                self.model, "prepare_decoder_input_ids_from_labels"
+            ):
+                batch[
+                    "rejected_decoder_input_ids"
+                ] = self.model.prepare_decoder_input_ids_from_labels(
+                    labels=batch["rejected_labels"]
+                )
+                batch[
+                    "chosen_decoder_input_ids"
+                ] = self.model.prepare_decoder_input_ids_from_labels(
+                    labels=batch["chosen_labels"]
+                )
+
+        batch["prompt"] = prompt
+        batch["chosen"] = prompt + chosen
+        batch["rejected"] = prompt + rejected
+        batch["chosen_response_only"] = chosen
+        batch["rejected_response_only"] = rejected
+
+        return batch
+
+    def collate(self, batch):
+        # first, pad everything to the same length
+        padded_batch = {}
+        for k in batch[0].keys():
+            if (
+                k.endswith("_input_ids")
+                or k.endswith("_attention_mask")
+                or k.endswith("_labels")
+            ):
+                if self.is_encoder_decoder:
+                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]
+
+                    if (k.startswith("prompt")) and (k.endswith("input_ids")):
+                        padding_value = self.tokenizer.pad_token_id
+                    elif k.endswith("_attention_mask"):
+                        padding_value = 0
+                    elif (
+                        (k.startswith("chosen"))
+                        or (k.startswith("rejected"))
+                        or ("decoder" in k)
+                    ):
+                        padding_value = self.label_pad_token_id
+                    else:
+                        raise ValueError(f"Unexpected key in batch '{k}'")
+                    padded_batch[k] = pad_sequence(
+                        to_pad, batch_first=True, padding_value=padding_value
+                    )
+                else:
+                    # adapted from https://stackoverflow.com/questions/73256206
+                    if "prompt" in k:
+                        to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
+                    else:
+                        to_pad = [torch.LongTensor(ex[k]) for ex in batch]
+                    if k.endswith("_input_ids"):
+                        padding_value = self.tokenizer.pad_token_id
+                    elif k.endswith("_labels"):
+                        padding_value = self.label_pad_token_id
+                    elif k.endswith("_attention_mask"):
+                        padding_value = self.padding_value
+                    else:
+                        raise ValueError(f"Unexpected key in batch '{k}'")
+
+                    padded_batch[k] = pad_sequence(
+                        to_pad, batch_first=True, padding_value=padding_value
+                    )
+                    # for the prompt, flip back so padding is on left side
+                    if "prompt" in k:
+                        padded_batch[k] = padded_batch[k].flip(dims=[1])
+            else:
+                padded_batch[k] = [ex[k] for ex in batch]
+
+        return padded_batch
+
+    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+        tokenized_batch = []
+
+        for feature in features:
+            prompt = feature["prompt"]
+            chosen = feature["chosen"]
+            rejected = feature["rejected"]
+
+            batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
+            tokenized_batch.append(batch_element)
+
+        # return collated batch
+        return self.collate(tokenized_batch)
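
Illustrative usage of the vendored collator (a minimal sketch, not part of the patch; the import path, checkpoint, and length limits below are assumptions):

    from transformers import AutoTokenizer

    from datadreamer.trainers._dpo_helper import DPODataCollatorWithPadding  # assumed import path

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    collator = DPODataCollatorWithPadding(
        tokenizer=tokenizer, max_length=32, max_prompt_length=16
    )
    batch = collator(
        [
            {"prompt": "2 + 2 =", "chosen": " 4", "rejected": " 5"},
            {"prompt": "Capital of France?", "chosen": " Paris", "rejected": " Berlin"},
        ]
    )
    # Prompt tensors are left-padded, response tensors right-padded, and the
    # prompt positions in *_labels are masked with -100.
    print(batch["chosen_input_ids"].shape, batch["chosen_labels"][0])
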
diff --git a/src/trainers/_train_hf_base.py b/src/trainers/_train_hf_base.py
index 35a241a6..605f9bc7 100644
--- a/src/trainers/_train_hf_base.py
+++ b/src/trainers/_train_hf_base.py
@@ -55,7 +55,7 @@
     get_tokenizer,
     is_encoder_decoder,
 )
-from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings
+from ..utils.import_utils import ignore_transformers_warnings
 from .trainer import Trainer as DataDreamerTrainer
 
 with ignore_transformers_warnings():
@@ -552,8 +552,7 @@ def prepare_for_reward_pairs(row):
                 else 120,
             ).output.dataset
         elif dpo:
-            with ignore_trl_warnings():
-                from trl.trainer.utils import DPODataCollatorWithPadding
+            from ._dpo_helper import DPODataCollatorWithPadding
 
             # Get data collator
             data_collator = DPODataCollatorWithPadding(
diff --git a/src/trainers/train_hf_dpo.py b/src/trainers/train_hf_dpo.py
index 11d9429f..8bdd900f 100644
--- a/src/trainers/train_hf_dpo.py
+++ b/src/trainers/train_hf_dpo.py
@@ -150,6 +150,10 @@ def _train(  # type:ignore[override]
             dpo=True,
         )
 
+        # We have already tokenized the dataset, so don't let DPOTrainer try to tokenize.
+        train_dataset.map = lambda *args, **kwargs: train_dataset
+        validation_dataset.map = lambda *args, **kwargs: validation_dataset
+
         # Prepare compute metrics
         compute_metrics = kwargs.pop("compute_metrics", None)
 
@@ -369,7 +373,7 @@ def train(  # type:ignore[override]
         warmup_steps: int = 0,
         neftune_noise_alpha: None | float = None,
         dpo_beta: float = 0.1,
-        loss_type: str = "sigmoid",
+        loss_type: str = "kto_pair",
         disable_dropout: bool = True,
         seed: int = 42,
         **kwargs,
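
A note on the `.map` override added above: DPOTrainer tokenizes its datasets internally via `dataset.map(...)` (which is what the patch comment alludes to), so re-binding `.map` to a no-op hands DPOTrainer the already-tokenized dataset unchanged. A minimal sketch of the idea (the toy dataset and column names are hypothetical):

    from datasets import Dataset

    train_dataset = Dataset.from_dict(
        {"prompt_input_ids": [[1, 2, 3]], "chosen_labels": [[-100, -100, 7]]}
    )
    train_dataset.map = lambda *args, **kwargs: train_dataset  # no-op, as in the patch
    assert train_dataset.map(lambda row: row) is train_dataset  # columns pass through untouched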