From 13124236336ad17ecffe43143b9c072849c4fc2c Mon Sep 17 00:00:00 2001 From: Ajay Patel Date: Tue, 30 Apr 2024 01:00:29 -0400 Subject: [PATCH] Improve FSDP + QLora (#25) [release] * Fix simple issues * Default eval_accumulation_steps * Filter model warnings * Support FSDP activation checkpointing * Refactor training code * Reduce CPU RAM usage by compute_metrics * Fix tests * Fix tests * Update RAM * Add test_StrWithSeed --- pyproject.toml | 2 +- scripts/.cluster/slurm/_sbatch_config.sh | 2 +- src/_cachable/_cachable.py | 10 +- .../sentence_transformers_embedder.py | 4 + src/llms/hf_transformers.py | 13 +- src/llms/petals.py | 4 + .../hf_classification_task_model.py | 4 + src/tests/llms/test_llms.py | 22 +- src/tests/trainers/test_distributed.py | 5 +- src/tests/trainers/test_trainers.py | 2 +- src/tests/utils/test_device_utils.py | 5 +- src/trainers/_train_hf_base.py | 930 +---------------- src/trainers/train_hf_classifier.py | 27 +- src/trainers/train_hf_dpo.py | 22 +- src/trainers/train_hf_finetune.py | 44 +- src/trainers/train_hf_ppo.py | 17 +- src/trainers/train_hf_reward_model.py | 38 +- src/trainers/train_sentence_transformer.py | 38 +- src/trainers/train_setfit_classifier.py | 30 +- src/trainers/trainer.py | 25 +- src/utils/device_utils.py | 1 + src/utils/distributed_utils.py | 35 +- src/utils/hf_model_utils.py | 22 + src/utils/hf_training_utils.py | 973 ++++++++++++++++++ src/utils/import_utils.py | 19 - 25 files changed, 1235 insertions(+), 1059 deletions(-) create mode 100644 src/utils/hf_training_utils.py diff --git a/pyproject.toml b/pyproject.toml index f78c1ff7..0f2ab30d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "DataDreamer" -version = "0.31.0" +version = "0.32.0" description = "Prompt. Generate Synthetic Data. Train & Align Models." license = "MIT" authors= [ diff --git a/scripts/.cluster/slurm/_sbatch_config.sh b/scripts/.cluster/slurm/_sbatch_config.sh index ebed403d..75ee5a86 100755 --- a/scripts/.cluster/slurm/_sbatch_config.sh +++ b/scripts/.cluster/slurm/_sbatch_config.sh @@ -5,7 +5,7 @@ #SBATCH --output=.cluster/slurm/.last_job/submission.out #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 -#SBATCH --mem=10G +#SBATCH --mem=30G #SBATCH --gpus=2 # Source the user's bashrc diff --git a/src/_cachable/_cachable.py b/src/_cachable/_cachable.py index 7122d8ce..f85f4fa5 100644 --- a/src/_cachable/_cachable.py +++ b/src/_cachable/_cachable.py @@ -60,7 +60,7 @@ def _notify_adaptive_batch_sizing(model_logger: Logger, progress_state: dict[str class _StrWithSeed(str): seed: Any - def __new__(cls, value: str, seed: "Any | _StrWithSeed"): + def __new__(cls, value: str, seed: "Any | _StrWithSeed" = None): obj = str.__new__(cls, value) obj.seed = seed.seed if isinstance(seed, _StrWithSeed) else seed return obj @@ -75,6 +75,14 @@ def __eq__(self, __value: object) -> bool: def __hash__(self): return hash((self.seed, str(self))) + def __getstate__(self): + state = {"str": str(self), "seed": self.seed} + + return state + + def __setstate__(self, state): + self.seed = state["seed"] + @staticmethod def total_per_input_seeds(inputs: list["str | _StrWithSeed"]) -> int: return sum( diff --git a/src/embedders/sentence_transformers_embedder.py b/src/embedders/sentence_transformers_embedder.py index 6feb2aa2..1e345ab7 100644 --- a/src/embedders/sentence_transformers_embedder.py +++ b/src/embedders/sentence_transformers_embedder.py @@ -19,6 +19,7 @@ ) from ..utils.hf_model_utils import ( convert_dtype, + filter_model_warnings, get_model_max_context_length, get_tokenizer, ) @@ -122,6 +123,9 @@ def model(self) -> SentenceTransformer: # torch._dynamo.config.suppress_errors = True # model = torch.compile(model) + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial( diff --git a/src/llms/hf_transformers.py b/src/llms/hf_transformers.py index 5c64030f..82151ef4 100644 --- a/src/llms/hf_transformers.py +++ b/src/llms/hf_transformers.py @@ -21,6 +21,7 @@ HF_TRANSFORMERS_CITATION, PEFT_CITATION, convert_dtype, + filter_model_warnings, get_attn_implementation, get_config, get_model_max_context_length, @@ -273,6 +274,9 @@ def model(self) -> PreTrainedModel: torch._dynamo.config.suppress_errors = True model = torch.compile(model) + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial( @@ -323,15 +327,6 @@ def count_tokens(self, value: str) -> int: Returns: The number of tokens in the string. """ - pass - """_summary_ - - Args: - value (_type_): _description_ - - Returns: - _type_: _description_ - """ return len(self.tokenizer.encode(value)) @torch.no_grad() diff --git a/src/llms/petals.py b/src/llms/petals.py index 0e365014..66a7a3bb 100644 --- a/src/llms/petals.py +++ b/src/llms/petals.py @@ -11,6 +11,7 @@ from ..utils.arg_utils import AUTO, Default from ..utils.background_utils import RunIfTimeout from ..utils.fs_utils import safe_fn +from ..utils.hf_model_utils import filter_model_warnings from ..utils.import_utils import ignore_hivemind_warnings, ignore_transformers_warnings from .hf_transformers import HFTransformers @@ -161,6 +162,9 @@ def model(self) -> PreTrainedModel: # torch._dynamo.config.suppress_errors = True # model = torch.compile(model) + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial( diff --git a/src/task_models/hf_classification_task_model.py b/src/task_models/hf_classification_task_model.py index cab48263..cfc2871e 100644 --- a/src/task_models/hf_classification_task_model.py +++ b/src/task_models/hf_classification_task_model.py @@ -16,6 +16,7 @@ HF_TRANSFORMERS_CITATION, PEFT_CITATION, convert_dtype, + filter_model_warnings, get_config, get_model_max_context_length, get_tokenizer, @@ -152,6 +153,9 @@ def model(self) -> PreTrainedModel: # torch._dynamo.config.suppress_errors = True # model = torch.compile(model) + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial( diff --git a/src/tests/llms/test_llms.py b/src/tests/llms/test_llms.py index 0db4ab8f..5ec8d05d 100644 --- a/src/tests/llms/test_llms.py +++ b/src/tests/llms/test_llms.py @@ -12,6 +12,7 @@ from time import sleep from types import GeneratorType +import dill import psutil import pytest import torch @@ -19,7 +20,7 @@ from sortedcontainers import SortedDict from ... import DataDreamer -from ..._cachable._cachable import _is_primitive +from ..._cachable._cachable import _is_primitive, _StrWithSeed from ...llms import ( AI21, VLLM, @@ -338,6 +339,25 @@ def test_is_primitive(self): assert _is_primitive({"foo": 5}) assert not _is_primitive({"foo": object()}) + def test_StrWithSeed(self): + seed_a = _StrWithSeed("hello", seed=1) + seed_b = _StrWithSeed("hello", seed=2) + seed_c = _StrWithSeed("hello", seed=1) + assert ( + isinstance(seed_a, str) + and isinstance(seed_b, str) + and isinstance(seed_c, str) + ) + assert seed_a.seed == 1 + assert seed_b.seed == 2 + assert seed_c.seed == 1 + assert str(seed_a) == "hello" + assert str(seed_b) == "hello" + assert str(seed_c) == "hello" + assert hash(seed_a) != hash(seed_b) + assert hash(seed_a) == hash(seed_c) + assert hash(seed_a) == hash(dill.loads(dill.dumps(seed_c))) + def test_check_temperature_and_top_p(self): assert _check_temperature_and_top_p( temperature=0.3, diff --git a/src/tests/trainers/test_distributed.py b/src/tests/trainers/test_distributed.py index 160e5d07..42b220b3 100644 --- a/src/tests/trainers/test_distributed.py +++ b/src/tests/trainers/test_distributed.py @@ -17,9 +17,9 @@ TrainHFPPO, TrainSentenceTransformer, ) -from ...trainers._train_hf_base import CustomDataCollatorWithPadding from ...utils.arg_utils import AUTO from ...utils.hf_model_utils import get_orig_model, is_bnb_quantized +from ...utils.hf_training_utils import CustomDataCollatorWithPadding from ...utils.import_utils import ignore_transformers_warnings with ignore_transformers_warnings(): @@ -422,12 +422,13 @@ def test_fsdp_peft(self, qlora, create_datadreamer, mocker): validation_output=val_dataset.output["outputs"], epochs=1, batch_size=8, + gradient_checkpointing=qlora, ) assert data_collator_spy.call_count == 0 trainer_path = cast(str, trainer._output_folder_path) with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f: assert ( - json.load(f) == "ce4179deefbddefd" if qlora else "6b385aca0ce684b3" + json.load(f) == "42a7bd193f804a4a" if qlora else "6b385aca0ce684b3" ) assert train_result is trainer assert ( diff --git a/src/tests/trainers/test_trainers.py b/src/tests/trainers/test_trainers.py index dab6383d..2b114e0a 100644 --- a/src/tests/trainers/test_trainers.py +++ b/src/tests/trainers/test_trainers.py @@ -25,9 +25,9 @@ TrainSentenceTransformer, TrainSetFitClassifier, ) -from ...trainers._train_hf_base import CustomDataCollatorWithPadding from ...utils.fs_utils import clear_dir from ...utils.hf_model_utils import get_orig_model, validate_peft_config +from ...utils.hf_training_utils import CustomDataCollatorWithPadding from ...utils.import_utils import ignore_transformers_warnings with ignore_transformers_warnings(): diff --git a/src/tests/utils/test_device_utils.py b/src/tests/utils/test_device_utils.py index fbb8c326..1ed90b6c 100644 --- a/src/tests/utils/test_device_utils.py +++ b/src/tests/utils/test_device_utils.py @@ -89,4 +89,7 @@ def test_get_device_env_variables(self): get_device_env_variables([0, 2, 999999, 0, 1, -1, -1]) with pytest.raises(AssertionError): get_device_env_variables([0, 2, 0, 1]) - assert get_device_env_variables([0, 2, 1]) == {"CUDA_VISIBLE_DEVICES": "6,3,4"} + assert get_device_env_variables([0, 2, 1]) == { + "CUDA_VISIBLE_DEVICES": "6,3,4", + "NCCL_P2P_DISABLE": "1", + } diff --git a/src/trainers/_train_hf_base.py b/src/trainers/_train_hf_base.py index abcdbd35..ff57b1b6 100644 --- a/src/trainers/_train_hf_base.py +++ b/src/trainers/_train_hf_base.py @@ -1,40 +1,21 @@ import json -import logging import os -import sys from copy import copy from functools import cached_property, partial from io import BytesIO -from itertools import chain from shutil import copy2 -from typing import TYPE_CHECKING, Any, Callable, Type, cast +from typing import TYPE_CHECKING, Any, Type, cast -import numpy as np import torch -from datasets import Dataset, IterableDataset, Value, concatenate_datasets from datasets.fingerprint import Hasher -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR from .. import DataDreamer -from ..datasets import ( - OutputDatasetColumn, - OutputIterableDataset, - OutputIterableDatasetColumn, -) -from ..datasets.datasets import _SizedIterableDataset, get_sized_dataset +from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn from ..logging import logger -from ..steps import DataSource from ..utils.arg_utils import AUTO, DEFAULT, Default, default_to from ..utils.background_utils import RunIfTimeout -from ..utils.device_utils import ( - _TrainingArgumentDeviceOverrideMixin, - get_device_memory_monitoring_callback, - model_to_device, - validate_device, -) +from ..utils.device_utils import model_to_device, validate_device from ..utils.distributed_utils import ( - get_global_rank, get_num_nodes_from_distributed_config, is_distributed, not_distributed_or_main_process, @@ -56,6 +37,7 @@ HF_TRANSFORMERS_CITATION, PEFT_CITATION, convert_dtype, + filter_model_warnings, get_attn_implementation, get_config, get_model_optional_kwargs, @@ -70,907 +52,17 @@ from .trainer import Trainer as DataDreamerTrainer with ignore_transformers_warnings(): - from setfit import logging as setfit_logging from transformers import ( AutoModelForCausalLM, AutoModelForSeq2SeqLM, PreTrainedModel, PreTrainedTokenizer, - TrainerCallback, - logging as hf_transformers_logging, ) from transformers.utils.quantization_config import QuantizationConfigMixin -from transformers import ( - Seq2SeqTrainingArguments as _Seq2SeqTrainingArguments, - TrainingArguments as _TrainingArguments, -) if TYPE_CHECKING: # pragma: no cover - with ignore_transformers_warnings(): - from transformers import Trainer - - -class TrainingArguments(_TrainingArgumentDeviceOverrideMixin, _TrainingArguments): - pass - - -class Seq2SeqTrainingArguments( - _TrainingArgumentDeviceOverrideMixin, _Seq2SeqTrainingArguments -): - pass - - -def _wrap_trainer_cls( - trainer_cls: Type["Trainer"], - optimizers: tuple[None | Optimizer, None | LambdaLR] = (None, None), - optimizer: None | Optimizer = None, - lr_scheduler: None | LambdaLR = None, - compute_loss: None | Callable = None, -) -> Type["Trainer"]: - class WrappedTrainer(trainer_cls): - def create_optimizer(self): - final_optimizer = optimizer or optimizers[0] - if final_optimizer is not None: # pragma: no cover - self.optimizer = final_optimizer - else: - super().create_optimizer() - - def create_scheduler( - self, num_training_steps: int, optimizer: None | Optimizer = None - ): - final_lr_scheduler = lr_scheduler or optimizers[1] - if final_lr_scheduler is not None: # pragma: no cover - self.lr_scheduler = final_lr_scheduler - else: - super().create_scheduler( - num_training_steps=num_training_steps, optimizer=optimizer - ) - - def compute_loss(self, model, inputs, return_outputs=False): - if compute_loss is not None: # pragma: no cover - return compute_loss(model, inputs, return_outputs=return_outputs) - else: - return super().compute_loss( - model, inputs, return_outputs=return_outputs - ) - - return WrappedTrainer - - -def _prepare_inputs_and_outputs( # noqa: C901 - self: "_TrainHFBase", - train_columns: dict[ - tuple[str, str], OutputDatasetColumn | OutputIterableDatasetColumn - ], - validation_columns: dict[ - tuple[str, str], OutputDatasetColumn | OutputIterableDatasetColumn - ], - truncate: bool = False, - causal: bool = False, - dpo: bool = False, - reward_pairs: bool = False, - reward_scores: bool = False, -) -> tuple[ - Dataset | IterableDataset | _SizedIterableDataset, - Dataset | IterableDataset | _SizedIterableDataset, - dict[Any, int], - bool, -]: - num_proc = ( - ( - len(os.sched_getaffinity(0)) - if hasattr(os, "sched_getaffinity") - else os.cpu_count() - ) - if sys.platform != "darwin" - else 1 - ) - label2id: dict[Any, int] = {} - is_multi_target: bool = False - - def get_train_column( - column_name: str, - ) -> OutputDatasetColumn | OutputIterableDatasetColumn: - for (train_column_name, _), train_column in train_columns.items(): - if train_column_name == column_name: - return train_column - raise KeyError(f"Train column {column_name} not found.") # pragma: no cover - - def get_validation_column( - column_name: str, - ) -> OutputDatasetColumn | OutputIterableDatasetColumn: - for ( - validation_column_name, - _, - ), validation_column in validation_columns.items(): - if validation_column_name == column_name: - return validation_column - raise KeyError( - f"Validation column {column_name} not found." - ) # pragma: no cover - - def apply_chat_prompt_template(prompt: str) -> str: - return ( - cast(str, self.chat_prompt_template) - .replace("{{system_prompt}}", self.system_prompt or "") - .replace("{{prompt}}", prompt) - ) - - def tokenize_function( - examples, - column_name: str, - new_column_name: str, - causal: bool, - reward_scores: bool, - ): # pragma: no cover - if reward_scores: - prompt, completion = examples[column_name] - if self.chat_prompt_template: - prompt = apply_chat_prompt_template(prompt) - input_ids = self.tokenizer( - prompt + completion, - truncation=truncate, - padding=False, - add_special_tokens=True, - )["input_ids"] - return { - "input_ids": input_ids[: self.tokenizer.model_max_length] - if truncate - else input_ids, - "labels": examples["label"], - } - elif causal: - prompt, completion = examples[column_name] - if self.chat_prompt_template: - prompt = apply_chat_prompt_template(prompt) - prompt_input_ids = self.tokenizer( - prompt, truncation=truncate, padding=False, add_special_tokens=True - )["input_ids"] - completion_input_ids = self.tokenizer( - completion, truncation=truncate, padding=False, add_special_tokens=False - )["input_ids"] + [self.tokenizer.eos_token_id] - prompt_labels = [-100] * len(prompt_input_ids) - input_ids = prompt_input_ids + completion_input_ids - labels = prompt_labels + completion_input_ids - return { - "input_ids": input_ids[: self.tokenizer.model_max_length] - if truncate - else input_ids, - "labels": labels[: self.tokenizer.model_max_length] - if truncate - else labels, - } - elif new_column_name in ["decoder_labels"]: - return { - "labels": self.tokenizer( - examples[column_name], - truncation=truncate, - padding=False, - add_special_tokens=True, - )["input_ids"] - } - else: - prompts = examples[column_name] - if self.chat_prompt_template: - prompts = list(map(apply_chat_prompt_template, prompts)) - tokenizer_results = self.tokenizer( - prompts, truncation=truncate, padding=False, add_special_tokens=True - ) - return { - new_column_name: tokenizer_results["input_ids"], - f"{new_column_name.replace('input_ids', '')}attention_mask": tokenizer_results[ - "attention_mask" - ], - } - - def tokenize_column_name( - column_name: str, - new_column_name: str, - causal: bool, - reward_scores: bool = False, - ) -> Callable: - return partial( - tokenize_function, - column_name=column_name, - new_column_name=new_column_name, - causal=causal, - reward_scores=reward_scores, - ) - - def tokenize_column( - column: OutputDatasetColumn | OutputIterableDatasetColumn, - new_column_name: str, - name: str, - causal: bool = False, - reward_scores: bool = False, - ) -> Dataset | IterableDataset: - column_name = column.column_names[0] - return column.step.map( - name=f"Tokenize {name}", - function=tokenize_column_name( - column_name, - new_column_name=new_column_name, - causal=causal, - reward_scores=reward_scores, - ), - batched=not causal and not reward_scores, - remove_columns=column.step.output.column_names, - total_num_rows=column.num_rows, - auto_progress=column.num_rows is not None, - lazy=isinstance(column, OutputIterableDatasetColumn), - progress_interval=sys.maxsize - if isinstance(column, OutputIterableDatasetColumn) - else 120, - save_num_proc=num_proc, - ).output.dataset - - def rename_column( - column: OutputDatasetColumn | OutputIterableDatasetColumn, new_column_name: str - ) -> Dataset | IterableDataset: - column_name = column.column_names[0] - column_dataset = column.step.output.dataset.select_columns(column.column_names) - return ( - column_dataset.rename_column(column_name, new_column_name) - if column_name != new_column_name - else column_dataset - ) - - def label_encode_function( - _, column_name: str, example: dict[str, Any] - ) -> dict[str, Any]: # pragma: no cover - if isinstance(example[column_name], list): - row_labels = set(str(label) for label in example[column_name]) - return { - column_name: [1 if label in row_labels else 0 for label in label2id] - } - else: - return {column_name: label2id[str(example[column_name])]} - - def label2id_column( - column: OutputDatasetColumn | OutputIterableDatasetColumn, - new_column_name: str, - name: str, - ) -> Dataset | IterableDataset: - column_name = column.column_names[0] - return rename_column( - column.step.map( - name=f"Encode {name} labels", - function=partial( - label_encode_function, sorted(label2id.keys()), column_name - ), - batched=False, - remove_columns=list( - set(column.step.output.column_names).difference(set([column_name])) - ), - total_num_rows=column.num_rows, - auto_progress=column.num_rows is not None, - lazy=isinstance(column, OutputIterableDatasetColumn), - progress_interval=sys.maxsize - if isinstance(column, OutputIterableDatasetColumn) - else 120, - save_num_proc=num_proc, - ).output[column_name], - new_column_name, - ) - - def process_column( - column: OutputDatasetColumn | OutputIterableDatasetColumn, - new_column_name: str, - name: str, - ) -> Dataset | IterableDataset: - if new_column_name == "label" and reward_scores is False: - return label2id_column( - column=column, new_column_name=new_column_name, name=name - ) - else: # pragma: no cover - return rename_column(column=column, new_column_name=new_column_name) - - def concatenate_prompts_and_completions( - dataset: Dataset | IterableDataset, - ) -> IterableDataset: - iterable_dataset = ( - dataset.to_iterable_dataset() if isinstance(dataset, Dataset) else dataset - ) - return iterable_dataset.map( - lambda row: {"text": [row["prompt"], row["completion"]]}, - remove_columns=["prompt", "completion"], - ) - - # Calculate label2id - uniq_labels = [] - for (new_column_name, name), column in list(train_columns.items()) + list( - validation_columns.items() - ): - column_name = column.column_names[0] - - def uniqify_labels(labels: set[Any], column_name, example): - nonlocal is_multi_target - if isinstance(example[column_name], list): - is_multi_target = True - is_new = False - for label in example[column_name]: - if label not in labels: - is_new = True - labels.add(label) - return is_new - else: - is_new = example[column_name] not in labels - labels.add(example[column_name]) - return is_new - - if new_column_name == "label" and reward_scores is False: - uniq_labels_column = column.step.filter( - name=f"Get all {name} label names", - function=partial(uniqify_labels, set(), column_name), - batched=False, - total_num_rows=column.num_rows, - auto_progress=column.num_rows is not None, - lazy=False, - progress_interval=sys.maxsize - if isinstance(column, OutputIterableDatasetColumn) - else 120, - ).output[column_name] - uniq_labels_from_column = list(uniq_labels_column) - uniq_labels += ( - list(chain.from_iterable(uniq_labels_column)) - if len(uniq_labels_from_column) > 0 - and isinstance(uniq_labels_from_column[0], list) - else uniq_labels_column - ) - uniq_labels = sorted(set(uniq_labels)) - for label in uniq_labels: - label2id[str(label)] = len(label2id) - - # Create train and validation datasets - train_dataset: Dataset | IterableDataset - validation_dataset: Dataset | IterableDataset - if reward_pairs: - # Check if scores are provided - try: - get_train_column("train_chosen_scores") - has_scores = True - except KeyError: - has_scores = False - - # Get data collator - def prepare_for_reward_pairs(row): # pragma: no cover - row = row.copy() - if self.chat_prompt_template: - row["prompt"] = apply_chat_prompt_template(row["prompt"]) - row["chosen"] = row["prompt"] + row["chosen"] - row["rejected"] = row["prompt"] + row["rejected"] - reward_results = {} - chosen_tokenizer_results = self.tokenizer( - row["chosen"], - truncation=truncate, - padding=False, - add_special_tokens=True, - ) - reward_results["input_ids_chosen"] = chosen_tokenizer_results["input_ids"] - rejected_tokenizer_results = self.tokenizer( - row["rejected"], - truncation=truncate, - padding=False, - add_special_tokens=True, - ) - reward_results["input_ids_rejected"] = rejected_tokenizer_results[ - "input_ids" - ] - if "chosen_scores" in row and "rejected_scores" in row: - reward_results["margin"] = row["chosen_scores"] - row["rejected_scores"] - return reward_results - - # Run data collator - train_columns_to_combine = [ - rename_column(get_train_column("train_prompts"), "prompt"), - rename_column(get_train_column("train_chosen"), "chosen"), - rename_column(get_train_column("train_rejected"), "rejected"), - ] - if has_scores: - train_columns_to_combine.extend( - [ - rename_column( - get_train_column("train_chosen_scores"), "chosen_scores" - ), - rename_column( - get_train_column("train_rejected_scores"), "rejected_scores" - ), - ] - ) - train_combine_step = DataSource( - "Combine Train Prompts, Chosen Generations, and Rejected Generations", - data=concatenate_datasets(train_columns_to_combine, axis=1), - total_num_rows=get_train_column("train_prompts").num_rows, - auto_progress=get_train_column("train_prompts").num_rows is not None, - ) - train_dataset = train_combine_step.map( - name="Prepare Train Dataset for Reward Model Training", - function=prepare_for_reward_pairs, - batched=False, - remove_columns=train_combine_step.output.column_names, - total_num_rows=get_train_column("train_prompts").num_rows, - auto_progress=get_train_column("train_prompts").num_rows is not None, - lazy=isinstance(train_combine_step.output, OutputIterableDataset), - progress_interval=sys.maxsize - if isinstance(train_combine_step.output, OutputIterableDataset) - else 120, - save_num_proc=num_proc, - ).output.dataset - validation_columns_to_combine = [ - rename_column(get_validation_column("validation_prompts"), "prompt"), - rename_column(get_validation_column("validation_chosen"), "chosen"), - rename_column(get_validation_column("validation_rejected"), "rejected"), - ] - if has_scores: - validation_columns_to_combine.extend( - [ - rename_column( - get_validation_column("validation_chosen_scores"), - "chosen_scores", - ), - rename_column( - get_validation_column("validation_rejected_scores"), - "rejected_scores", - ), - ] - ) - validation_combine_step = DataSource( - "Combine Validation Prompts, Chosen Generations, and Rejected Generations", - data=concatenate_datasets(validation_columns_to_combine, axis=1), - total_num_rows=get_validation_column("validation_prompts").num_rows, - auto_progress=get_validation_column("validation_prompts").num_rows - is not None, - ) - validation_dataset = validation_combine_step.map( - name="Prepare Validation Dataset for Reward Model Training", - function=prepare_for_reward_pairs, - batched=False, - remove_columns=validation_combine_step.output.column_names, - total_num_rows=get_validation_column("validation_prompts").num_rows, - auto_progress=get_validation_column("validation_prompts").num_rows - is not None, - lazy=isinstance(validation_combine_step.output, OutputIterableDataset), - progress_interval=sys.maxsize - if isinstance(validation_combine_step.output, OutputIterableDataset) - else 120, - save_num_proc=num_proc, - ).output.dataset - elif dpo: - if TYPE_CHECKING: # pragma: no cover - DPODataCollatorWithPadding: Any = None - else: - from ._vendored._dpo_helper import DPODataCollatorWithPadding - - # Get data collator - data_collator = DPODataCollatorWithPadding( - tokenizer=self.tokenizer, - max_length=self.tokenizer.model_max_length if truncate else sys.maxsize, - max_prompt_length=self.tokenizer.model_max_length - if truncate - else sys.maxsize, - label_pad_token_id=-100, - padding_value=0, - truncation_mode="keep_end", - is_encoder_decoder=self._is_encoder_decoder, - max_target_length=self.tokenizer.model_max_length - if truncate - else sys.maxsize, - ) - - def run_data_collator(row): # pragma: no cover - if self.chat_prompt_template: - row["prompt"] = apply_chat_prompt_template(row["prompt"]) - dpo_results = data_collator.__call__([row]) - for key, value in list(dpo_results.items()): - if "attention_mask" in key: - del dpo_results[key] - elif isinstance(value, list) and len(value) == 1: - dpo_results[key] = value[0] - elif isinstance(value, torch.Tensor) and len(value.shape) == 2: - value = value[0] - if truncate: - dpo_results[key] = value[: self.tokenizer.model_max_length] - return dpo_results - - # Run data collator - train_combine_step = DataSource( - "Combine Train Prompts, Chosen Generations, and Rejected Generations", - data=concatenate_datasets( - [ - rename_column(get_train_column("train_prompts"), "prompt"), - rename_column(get_train_column("train_chosen"), "chosen"), - rename_column(get_train_column("train_rejected"), "rejected"), - ], - axis=1, - ), - total_num_rows=get_train_column("train_prompts").num_rows, - auto_progress=get_train_column("train_prompts").num_rows is not None, - ) - train_dataset = train_combine_step.map( - name="Prepare Train Dataset for DPO", - function=run_data_collator, - batched=False, - total_num_rows=get_train_column("train_prompts").num_rows, - auto_progress=get_train_column("train_prompts").num_rows is not None, - lazy=isinstance(train_combine_step.output, OutputIterableDataset), - progress_interval=sys.maxsize - if isinstance(train_combine_step.output, OutputIterableDataset) - else 120, - save_num_proc=num_proc, - ).output.dataset - validation_combine_step = DataSource( - "Combine Validation Prompts, Chosen Generations, and Rejected Generations", - data=concatenate_datasets( - [ - rename_column( - get_validation_column("validation_prompts"), "prompt" - ), - rename_column(get_validation_column("validation_chosen"), "chosen"), - rename_column( - get_validation_column("validation_rejected"), "rejected" - ), - ], - axis=1, - ), - total_num_rows=get_validation_column("validation_prompts").num_rows, - auto_progress=get_validation_column("validation_prompts").num_rows - is not None, - ) - validation_dataset = validation_combine_step.map( - name="Prepare Validation Dataset for DPO", - function=run_data_collator, - batched=False, - total_num_rows=get_validation_column("validation_prompts").num_rows, - auto_progress=get_validation_column("validation_prompts").num_rows - is not None, - lazy=isinstance(validation_combine_step.output, OutputIterableDataset), - progress_interval=sys.maxsize - if isinstance(validation_combine_step.output, OutputIterableDataset) - else 120, - save_num_proc=num_proc, - ).output.dataset - elif reward_scores: - train_combined = concatenate_datasets( - [ - rename_column(get_train_column("train_input"), "prompt"), - rename_column(get_train_column("train_output"), "completion"), - rename_column(get_train_column("label"), "label").cast_column( - "label", Value("float64") - ), - ], - axis=1, - ) - train_dataset = tokenize_column( - DataSource( - "Concatenate Train Prompts and Generations", - data=concatenate_prompts_and_completions(train_combined), - total_num_rows=get_train_column("train_input").num_rows, - auto_progress=get_train_column("train_input").num_rows is not None, - save=not isinstance(train_combined, IterableDataset), - ).output["text"], - "input_ids", - "Train Dataset", - reward_scores=True, - ) - validation_combined = concatenate_datasets( - [ - rename_column(get_validation_column("validation_input"), "prompt"), - rename_column(get_validation_column("validation_output"), "completion"), - rename_column(get_validation_column("label"), "label").cast_column( - "label", Value("float64") - ), - ], - axis=1, - ) - validation_dataset = tokenize_column( - DataSource( - "Concatenate Validation Prompts and Generations", - data=concatenate_prompts_and_completions(validation_combined), - total_num_rows=get_validation_column("validation_input").num_rows, - auto_progress=get_validation_column("validation_input").num_rows - is not None, - save=not isinstance(validation_combined, IterableDataset), - ).output["text"], - "input_ids", - "Validation Dataset", - reward_scores=True, - ) - elif causal: - train_combined = concatenate_datasets( - [ - rename_column(get_train_column("train_input"), "prompt"), - rename_column(get_train_column("train_output"), "completion"), - ], - axis=1, - ) - train_dataset = tokenize_column( - DataSource( - "Concatenate Train Input and Output", - data=concatenate_prompts_and_completions(train_combined), - total_num_rows=get_train_column("train_input").num_rows, - auto_progress=get_train_column("train_input").num_rows is not None, - save=not isinstance(train_combined, IterableDataset), - ).output["text"], - "input_ids", - "Train Dataset", - causal=True, - ) - validation_combined = concatenate_datasets( - [ - rename_column(get_validation_column("validation_input"), "prompt"), - rename_column(get_validation_column("validation_output"), "completion"), - ], - axis=1, - ) - validation_dataset = tokenize_column( - DataSource( - "Concatenate Validation Input and Output", - data=concatenate_prompts_and_completions(validation_combined), - total_num_rows=get_validation_column("validation_input").num_rows, - auto_progress=get_validation_column("validation_input").num_rows - is not None, - save=not isinstance(validation_combined, IterableDataset), - ).output["text"], - "input_ids", - "Validation Dataset", - causal=True, - ) - else: - train_dataset = concatenate_datasets( - [ - tokenize_column(train_column, train_column_name, name) - if train_column_name in ["input_ids", "decoder_labels"] - or train_column_name.endswith("_input_ids") - else process_column(train_column, train_column_name, name) - for (train_column_name, name), train_column in train_columns.items() - ], - axis=1, - ) - validation_dataset = concatenate_datasets( - [ - tokenize_column(validation_column, validation_column_name, name) - if validation_column_name in ["input_ids", "decoder_labels"] - or validation_column_name.endswith("_input_ids") - else process_column(validation_column, validation_column_name, name) - for ( - validation_column_name, - name, - ), validation_column in validation_columns.items() - ], - axis=1, - ) - - # Save information for publishing - train_step = list(train_columns.values())[0].step - self._step_metadata = train_step._get_metadata(train_step.output) - - # Save information for publishing - self._examples = { - name: ( - train_column.dataset[:3][train_column.column_names[0]] - if isinstance(train_column.dataset, Dataset) - else list( - map( - lambda row: row[train_column.column_names[0]], - train_column.dataset.take(3), - ) - ) - ) - for (_, name), train_column in train_columns.items() - } - if reward_scores: - if self.chat_prompt_template: - self._examples["Train Prompts"] = [ - apply_chat_prompt_template(prompt) - for prompt in self._examples["Train Prompts"] - ] - self._examples["Train Input"] = [ - prompt + generation - for prompt, generation in zip( - self._examples["Train Prompts"], self._examples["Train Generations"] - ) - ] - elif reward_pairs: - if self.chat_prompt_template: - self._examples["Train Prompts"] = [ - apply_chat_prompt_template(prompt) - for prompt in self._examples["Train Prompts"] - ] - chosen_examples = [ - prompt + generation - for prompt, generation in zip( - self._examples["Train Prompts"], - self._examples["Train Chosen Generations"], - ) - ] - rejected_examples = [ - prompt + generation - for prompt, generation in zip( - self._examples["Train Prompts"], - self._examples["Train Rejected Generations"], - ) - ] - self._examples["Train Input"] = list( - chain.from_iterable(zip(chosen_examples, rejected_examples)) - ) - elif dpo: - self._examples["Train Input"] = self._examples["Train Prompts"] - - # Return datasets - return ( - get_sized_dataset( - dataset=train_dataset, - total_num_rows=list(train_columns.values())[0].num_rows, - ), - get_sized_dataset( - dataset=validation_dataset, - total_num_rows=list(validation_columns.values())[0].num_rows, - ), - label2id, - is_multi_target, - ) - - -def _start_hf_trainer(self: "_TrainHFBase", trainer: Any): # noqa: C901 - # Setup loggers the way we need them to be - if not DataDreamer.ctx.hf_log: - if self.logger.level <= logging.NOTSET: # pragma: no cover - hf_transformers_trainer_logger = logging.getLogger("transformers.trainer") - if ( - not hf_transformers_trainer_logger.level - or hf_transformers_trainer_logger.level > logging.INFO - ): - hf_transformers_trainer_logger.level = logging.INFO - hf_transformers_trainer_logger.propagate = True - DataDreamer._enable_hf_transformers_logging(progress_bars=False) - DataDreamer._enable_setfit_logging(progress_bars=False) - hf_transformers_logging.set_verbosity_info() - setfit_logging.set_verbosity_info() - - # Add GPU monitoring if distributed - device_memory_monitoring_callback = get_device_memory_monitoring_callback( - trainer=self - ) - trainer.add_callback(device_memory_monitoring_callback) - - # Run training - try: - # Try to resume - if self.resumable: - trainer.train(resume_from_checkpoint=True) - else: - raise ValueError() - except ValueError: - # Nothing to resume from, so start a new training run - - # Evaluate before starting training so we can see how the model - # performs before any weight updates - if device_memory_monitoring_callback: - device_memory_monitoring_callback()._log_device_memory_usage() - if is_distributed() and trainer.is_fsdp_enabled: # pragma: no cover - from transformers.trainer import logger as trainer_logger - - # This is a hack to run .evaluate() before training happens on FSDP - # but after the FSDP is set up - old_info = trainer_logger.info - - def _info(old_info, *args, **kwargs): - if len(args) > 0 and args[0].startswith("***** Running training *****"): - trainer.evaluate() - trainer.model.train() # Switch the model back to train mode - trainer_logger.info = old_info # Undo the monkey-patch - return old_info(*args, **kwargs) - - trainer_logger.info = partial(_info, old_info) - else: - trainer.evaluate() - - # Start training - trainer.train() - if not DataDreamer.ctx.hf_log: - if self.logger.level <= logging.NOTSET: # pragma: no cover - logging.getLogger( - "transformers.trainer" - ).level = DataDreamer.ctx._transformers_trainer_verbosity - DataDreamer._disable_hf_transformers_logging() - DataDreamer._disable_setfit_logging() - - -class CustomDataCollatorWithPadding: - def __init__( - self, - tokenizer: PreTrainedTokenizer, - fields_to_pad: list[dict[str, Any]], - fields_to_keep: None | list[str] = None, - extra_column_names_to_add: None | dict[str, Any] = None, - ): - self.tokenizer = tokenizer - self.fields_to_pad = fields_to_pad - self.fields_to_keep = fields_to_keep - self.extra_column_names_to_add = extra_column_names_to_add - - def update_pad_token_id( - self, tensor: torch.Tensor, pad_token_id: int, keep_first_pad_token: bool - ): - # Find where the pad tokens are - pad_token_mask = tensor == self.tokenizer.pad_token_id - if keep_first_pad_token: - # Find the indices of the left-most pad token in each row - leftmost_true_indices = pad_token_mask.to(torch.int32).argmax(dim=1) - # Create a mask to help keep the left-most pad_token value - keep_leftmost_mask = ( - torch.arange(pad_token_mask.size(1)) <= leftmost_true_indices[:, None] - ) - # Apply the mask to the original mask - pad_token_mask = pad_token_mask & ~keep_leftmost_mask - # Update the pad token IDs - tensor[pad_token_mask] = pad_token_id - - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - result = {} - for field in self.fields_to_pad: - tokenizer = field.get("tokenizer", self.tokenizer) - pad_results = tokenizer.pad( - [{"input_ids": feature[field["name"]]} for feature in features], - padding=True, - return_tensors="pt", - ) - result[field["output_name"]] = pad_results["input_ids"] - if "pad_token_id" in field: - self.update_pad_token_id( - tensor=result[field["output_name"]], - pad_token_id=field["pad_token_id"], - keep_first_pad_token=field.get("keep_first_pad_token", False), - ) - if "output_attention_mask_name" in field: # pragma: no cover - result[field["output_attention_mask_name"]] = pad_results[ - "attention_mask" - ] - if isinstance(self.extra_column_names_to_add, dict): - for ( - column_name, - default_value, - ) in self.extra_column_names_to_add.items(): - result[column_name] = default_value - if self.fields_to_keep is not None: - for field_name in self.fields_to_keep: - result[field_name] = [ - feature[field_name] for feature in features if field_name in feature - ] - if len(result[field_name]) > 0 and isinstance( - result[field_name][0], (bool, int, float, np.ndarray, torch.Tensor) - ): - result[field_name] = torch.tensor(result[field_name]) - elif len(result[field_name]) == 0: - del result[field_name] - return result - - -def get_logging_callback(trainer: "_TrainHFBase", log_loss: bool = True) -> Type: - class LoggingCallback(TrainerCallback): - def on_log(self_, args, state, control, logs=None, **kwargs): - if is_distributed() and get_global_rank() != 0: # pragma: no cover - return - logs = logs.copy() - if "eval_progress" in logs and logs["eval_progress"] == "100%": - return - _ = logs.pop("total_flos", None) - _ = logs.pop("eval_joint_metric", None) - if state.is_local_process_zero: - epoch = logs.pop("epoch", 0.0) - if any([metric.startswith("eval_") for metric in logs.keys()]): - logs = {k.replace("eval_", ""): v for k, v in logs.items()} - if not log_loss: - logs.pop("loss") - trainer.logger.info(f"Eval Epoch: {epoch} -- {logs}") - else: - logs = {k.replace("train_", ""): v for k, v in logs.items()} - if not log_loss: - logs.pop("loss") - trainer.logger.info(f"Train Epoch: {epoch} -- {logs}") - - return LoggingCallback + from ..utils.hf_training_utils import TrainingArguments class _TrainHFBase(DataDreamerTrainer): @@ -1148,7 +240,9 @@ def _create_model( from peft import get_peft_model, prepare_model_for_kbit_training if self.quantization_config: # pragma: no cover - model = prepare_model_for_kbit_training(model) + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=True + ) model = get_peft_model( model, validate_peft_config(model=model, peft_config=self.peft_config) ) @@ -1160,6 +254,9 @@ def _create_model( else: model.train() + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial(lambda self: self.logger.info("Finished loading."), self) @@ -1188,7 +285,7 @@ def _publish_resource( def _save_model( self, - training_args: None | TrainingArguments, + training_args: "None | TrainingArguments", model: PreTrainedModel, tokenizer: PreTrainedTokenizer, accelerator: Any = None, @@ -1355,6 +452,9 @@ def _load_model( # model = torch.compile(model) pass + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial(lambda self: self.logger.info("Finished loading."), self) diff --git a/src/trainers/train_hf_classifier.py b/src/trainers/train_hf_classifier.py index 3b5dd2a2..e3597dd1 100644 --- a/src/trainers/train_hf_classifier.py +++ b/src/trainers/train_hf_classifier.py @@ -10,18 +10,20 @@ from torch.nn import functional as F from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn +from ..trainers.trainer import JointMetric from ..utils.arg_utils import AUTO, Default from ..utils.distributed_utils import not_distributed_or_main_process -from ..utils.import_utils import ignore_transformers_warnings -from ._train_hf_base import ( +from ..utils.hf_training_utils import ( TrainingArguments, - _prepare_inputs_and_outputs, - _start_hf_trainer, - _TrainHFBase, - _wrap_trainer_cls, + _monkey_patch_TrainerState__post_init__, get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, + wrap_compute_metrics, + wrap_trainer_cls, ) -from .trainer import JointMetric, _monkey_patch_TrainerState__post_init__ +from ..utils.import_utils import ignore_transformers_warnings +from ._train_hf_base import _TrainHFBase with ignore_transformers_warnings(): from transformers import ( @@ -115,7 +117,7 @@ def _train( # type:ignore[override] validation_dataset, label2id, is_multi_target, - ) = _prepare_inputs_and_outputs( + ) = prepare_inputs_and_outputs( self, train_columns={ ("input_ids", "Train Input"): train_input, @@ -253,6 +255,7 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): weight_decay=weight_decay, lr_scheduler_type=lr_scheduler_type, warmup_steps=warmup_steps, + eval_accumulation_steps=kwargs.pop("eval_accumulation_steps", 1), logging_strategy=kwargs.pop("logging_strategy", None) or "steps", logging_steps=kwargs.pop("logging_steps", 1), evaluation_strategy=kwargs.pop("evaluation_strategy", None) or "epoch", @@ -269,7 +272,7 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): ) # Setup trainer - trainer = _wrap_trainer_cls( + trainer = wrap_trainer_cls( trainer_cls=trainer_cls or Trainer, **trainer_override_kwargs )( train_dataset=train_dataset, @@ -277,7 +280,9 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): model=model, tokenizer=self.tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -285,7 +290,7 @@ def compute_accuracy_metrics(accuracy, f1, eval_pred): trainer.remove_callback(PrinterCallback) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( diff --git a/src/trainers/train_hf_dpo.py b/src/trainers/train_hf_dpo.py index 9e30df02..e39e3387 100644 --- a/src/trainers/train_hf_dpo.py +++ b/src/trainers/train_hf_dpo.py @@ -9,16 +9,17 @@ from ..steps.step_operations import _INTERNAL_STEP_OPERATION_KEY from ..utils.arg_utils import AUTO, Default from ..utils.distributed_utils import is_distributed, not_main_process -from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings -from ._train_hf_base import ( +from ..utils.hf_training_utils import ( CustomDataCollatorWithPadding, Seq2SeqTrainingArguments, TrainingArguments, - _prepare_inputs_and_outputs, - _start_hf_trainer, - _wrap_trainer_cls, get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, + wrap_compute_metrics, + wrap_trainer_cls, ) +from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings from .train_hf_finetune import TrainHFFineTune with ignore_transformers_warnings(): @@ -116,7 +117,7 @@ def _train( # type:ignore[override] # noqa: C901 assert ( self._is_encoder_decoder or truncate ), "`truncate=False` is not supported for this model." - train_dataset, validation_dataset, _, _ = _prepare_inputs_and_outputs( + train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs( self, train_columns={ ("train_prompts", "Train Prompts"): train_prompts, @@ -217,6 +218,7 @@ def _train( # type:ignore[override] # noqa: C901 weight_decay=weight_decay, lr_scheduler_type=lr_scheduler_type, warmup_steps=warmup_steps, + eval_accumulation_steps=kwargs.pop("eval_accumulation_steps", 1), logging_strategy=kwargs.pop("logging_strategy", None) or "steps", logging_steps=kwargs.pop("logging_steps", 1), evaluation_strategy=kwargs.pop("evaluation_strategy", None) or "epoch", @@ -314,7 +316,7 @@ def _train( # type:ignore[override] # noqa: C901 ] + other_fields_to_keep, ) - trainer = _wrap_trainer_cls( + trainer = wrap_trainer_cls( trainer_cls=trainer_cls or DPOTrainer, **trainer_override_kwargs )( label_pad_token_id=-100, @@ -326,7 +328,9 @@ def _train( # type:ignore[override] # noqa: C901 ref_model=ref_model, tokenizer=self.tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -428,7 +432,7 @@ def pre_compute_eval(): assert os.path.isfile(pre_compute_validation_step_done) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( diff --git a/src/trainers/train_hf_finetune.py b/src/trainers/train_hf_finetune.py index 67150606..e190b968 100644 --- a/src/trainers/train_hf_finetune.py +++ b/src/trainers/train_hf_finetune.py @@ -5,17 +5,18 @@ from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn from ..utils.arg_utils import AUTO, Default -from ..utils.import_utils import ignore_transformers_warnings -from ._train_hf_base import ( +from ..utils.hf_training_utils import ( CustomDataCollatorWithPadding, Seq2SeqTrainingArguments, TrainingArguments, - _prepare_inputs_and_outputs, - _start_hf_trainer, - _TrainHFBase, - _wrap_trainer_cls, get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, + wrap_compute_metrics, + wrap_trainer_cls, ) +from ..utils.import_utils import ignore_transformers_warnings +from ._train_hf_base import _TrainHFBase with ignore_transformers_warnings(): from transformers import ( @@ -107,7 +108,7 @@ def _train( # type:ignore[override] assert ( self._is_encoder_decoder or truncate ), "`truncate=False` is not supported for this model." - train_dataset, validation_dataset, _, _ = _prepare_inputs_and_outputs( + train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs( self, train_columns={ ( @@ -136,19 +137,37 @@ def _train( # type:ignore[override] ) # Prepare compute metrics + + # This computation can use a fair bit of CPU RAM due to the size of these + # tensors (batch_size * sequence_length * vocabulary_size), so we should try + # to save as much memory as possible + compute_perplexity_dtype = torch.float16 + try: + torch.nn.functional.cross_entropy( + input=torch.tensor([[1.0]], dtype=compute_perplexity_dtype), + target=torch.tensor([0], dtype=torch.long), + ) + except RuntimeError: + compute_perplexity_dtype = torch.float32 + def compute_perplexity_metrics(eval_pred): preds, labels = eval_pred + del eval_pred if isinstance(preds, tuple): preds = preds[0] - preds = torch.tensor(preds) + preds = torch.tensor(preds, dtype=compute_perplexity_dtype) labels = torch.tensor(labels) if self._is_encoder_decoder: nll = torch.nn.functional.cross_entropy( input=preds.view(-1, preds.size(-1)), target=labels.view(-1) ) else: + preds = preds.to(compute_perplexity_dtype) + labels = labels shift_preds = preds[..., :-1, :].contiguous() + del preds shift_labels = labels[..., 1:].contiguous() + del labels nll = torch.nn.functional.cross_entropy( input=shift_preds.view(-1, shift_preds.size(-1)), target=shift_labels.view(-1), @@ -216,6 +235,7 @@ def compute_perplexity_metrics(eval_pred): weight_decay=weight_decay, lr_scheduler_type=lr_scheduler_type, warmup_steps=warmup_steps, + eval_accumulation_steps=kwargs.pop("eval_accumulation_steps", 1), logging_strategy=kwargs.pop("logging_strategy", None) or "steps", logging_steps=kwargs.pop("logging_steps", 1), evaluation_strategy=kwargs.pop("evaluation_strategy", None) or "epoch", @@ -254,12 +274,14 @@ def compute_perplexity_metrics(eval_pred): ) trainer_cls = trainer_cls or Trainer trainer_args = {"data_collator": data_collator} - trainer = _wrap_trainer_cls(trainer_cls=trainer_cls, **trainer_override_kwargs)( + trainer = wrap_trainer_cls(trainer_cls=trainer_cls, **trainer_override_kwargs)( train_dataset=train_dataset, eval_dataset=validation_dataset, model=model, tokenizer=self.tokenizer, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -268,7 +290,7 @@ def compute_perplexity_metrics(eval_pred): trainer.remove_callback(PrinterCallback) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( diff --git a/src/trainers/train_hf_ppo.py b/src/trainers/train_hf_ppo.py index c1ea4e1c..18c9a4b3 100644 --- a/src/trainers/train_hf_ppo.py +++ b/src/trainers/train_hf_ppo.py @@ -18,14 +18,15 @@ from ..utils.arg_utils import AUTO, Default, default_to from ..utils.fs_utils import mkdir from ..utils.hf_model_utils import is_peft_model -from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings -from ._train_hf_base import ( +from ..utils.hf_training_utils import ( CustomDataCollatorWithPadding, TrainingArguments, - _prepare_inputs_and_outputs, - _start_hf_trainer, get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, + wrap_compute_metrics, ) +from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings from .train_hf_finetune import TrainHFFineTune from .train_hf_reward_model import TrainHFRewardModel @@ -507,7 +508,7 @@ def _train( # type:ignore[override] # noqa: C901 assert ( self._is_encoder_decoder or truncate ), "`truncate=False` is not supported for this model." - train_dataset, validation_dataset, _, _ = _prepare_inputs_and_outputs( + train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs( self, train_columns={("input_ids", "Train Input"): train_prompts}, validation_columns={("input_ids", "Validation Input"): validation_prompts}, @@ -768,7 +769,9 @@ def _train( # type:ignore[override] # noqa: C901 model=model, tokenizer=self.tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -776,7 +779,7 @@ def _train( # type:ignore[override] # noqa: C901 trainer.remove_callback(PrinterCallback) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( diff --git a/src/trainers/train_hf_reward_model.py b/src/trainers/train_hf_reward_model.py index c7dd9877..197b47a4 100644 --- a/src/trainers/train_hf_reward_model.py +++ b/src/trainers/train_hf_reward_model.py @@ -9,22 +9,24 @@ from torch.nn import functional as F from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn +from ..trainers.trainer import JointMetric from ..utils.arg_utils import AUTO, Default from ..utils.device_utils import _TrainingArgumentDeviceOverrideMixin from ..utils.distributed_utils import not_distributed_or_main_process from ..utils.hf_model_utils import get_base_model_from_peft_model -from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings -from ._train_hf_base import ( +from ..utils.hf_training_utils import ( CustomDataCollatorWithPadding, TrainingArguments, - _prepare_inputs_and_outputs, - _start_hf_trainer, - _TrainHFBase, - _wrap_trainer_cls, + _monkey_patch_TrainerState__post_init__, get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, + wrap_compute_metrics, + wrap_trainer_cls, ) +from ..utils.import_utils import ignore_transformers_warnings, ignore_trl_warnings +from ._train_hf_base import _TrainHFBase from .train_hf_classifier import TrainHFClassifier -from .trainer import JointMetric, _monkey_patch_TrainerState__post_init__ with ignore_transformers_warnings(): from transformers import EarlyStoppingCallback, PreTrainedModel @@ -186,7 +188,7 @@ def _train_with_pairs( ): validation_rejected_scores, } ) - train_dataset, validation_dataset, _, _ = _prepare_inputs_and_outputs( + train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs( self, train_columns=train_columns, validation_columns=validation_columns, @@ -309,6 +311,7 @@ class RewardConfig(_TrainingArgumentDeviceOverrideMixin, _RewardConfig): weight_decay=weight_decay, lr_scheduler_type=lr_scheduler_type, warmup_steps=warmup_steps, + eval_accumulation_steps=kwargs.pop("eval_accumulation_steps", 1), logging_strategy=kwargs.pop("logging_strategy", None) or "steps", logging_steps=kwargs.pop("logging_steps", 1), evaluation_strategy=kwargs.pop("evaluation_strategy", None) or "epoch", @@ -325,7 +328,7 @@ class RewardConfig(_TrainingArgumentDeviceOverrideMixin, _RewardConfig): ) # Setup trainer - trainer = _wrap_trainer_cls( + trainer = wrap_trainer_cls( trainer_cls=trainer_cls or RewardTrainer, **trainer_override_kwargs )( train_dataset=train_dataset, @@ -333,7 +336,9 @@ class RewardConfig(_TrainingArgumentDeviceOverrideMixin, _RewardConfig): model=model, tokenizer=self.tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -343,7 +348,7 @@ class RewardConfig(_TrainingArgumentDeviceOverrideMixin, _RewardConfig): trainer.remove_callback(PrinterCallback) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( @@ -402,7 +407,7 @@ def _train_with_scores( assert ( self._is_encoder_decoder or truncate ), "`truncate=False` is not supported for this model." - train_dataset, validation_dataset, _, _ = _prepare_inputs_and_outputs( + train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs( self, train_columns={ ("train_input", "Train Prompts"): train_prompts, @@ -504,6 +509,7 @@ def compute_mse_metrics(eval_pred): weight_decay=weight_decay, lr_scheduler_type=lr_scheduler_type, warmup_steps=warmup_steps, + eval_accumulation_steps=kwargs.pop("eval_accumulation_steps", 1), logging_strategy=kwargs.pop("logging_strategy", None) or "steps", logging_steps=kwargs.pop("logging_steps", 1), evaluation_strategy=kwargs.pop("evaluation_strategy", None) or "epoch", @@ -520,7 +526,7 @@ def compute_mse_metrics(eval_pred): ) # Setup trainer - trainer = _wrap_trainer_cls( + trainer = wrap_trainer_cls( trainer_cls=trainer_cls or Trainer, **trainer_override_kwargs )( train_dataset=train_dataset, @@ -528,7 +534,9 @@ def compute_mse_metrics(eval_pred): model=model, tokenizer=self.tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -536,7 +544,7 @@ def compute_mse_metrics(eval_pred): trainer.remove_callback(PrinterCallback) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( diff --git a/src/trainers/train_sentence_transformer.py b/src/trainers/train_sentence_transformer.py index cdc0ab40..75c0ee4e 100644 --- a/src/trainers/train_sentence_transformer.py +++ b/src/trainers/train_sentence_transformer.py @@ -13,26 +13,29 @@ from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn from ..embedders.sentence_transformers_embedder import _normalize_model_name +from ..trainers.trainer import JointMetric from ..utils.arg_utils import AUTO, DEFAULT, Default, default_to from ..utils.background_utils import RunIfTimeout from ..utils.hf_model_utils import ( + filter_model_warnings, get_base_model_from_peft_model, get_model_max_context_length, get_tokenizer, validate_peft_config, ) -from ..utils.import_utils import ignore_transformers_warnings -from ._train_hf_base import ( +from ..utils.hf_training_utils import ( CustomDataCollatorWithPadding, TrainingArguments, - _prepare_inputs_and_outputs, - _start_hf_trainer, - _TrainHFBase, - _wrap_trainer_cls, + _monkey_patch_TrainerState__post_init__, get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, + wrap_compute_metrics, + wrap_trainer_cls, ) +from ..utils.import_utils import ignore_transformers_warnings +from ._train_hf_base import _TrainHFBase from ._vendored import _sentence_transformer_helper -from .trainer import JointMetric, _monkey_patch_TrainerState__post_init__ with ignore_transformers_warnings(): from sentence_transformers import SentenceTransformer, losses @@ -263,12 +266,17 @@ def _create_model( from peft import get_peft_model, prepare_model_for_kbit_training if self.quantization_config: # pragma: no cover - model = prepare_model_for_kbit_training(model) + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=True + ) model = get_peft_model(model, validate_peft_config(model, self.peft_config)) # Switch model to train mode model.train() + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial(lambda self: self.logger.info("Finished loading."), self) @@ -403,7 +411,7 @@ def _train( # type:ignore[override] # noqa: C901 ] = validation_negatives if has_labels and validation_labels is not None: validation_columns[("labels", "Validation Labels")] = validation_labels - train_dataset, validation_dataset, _, _ = _prepare_inputs_and_outputs( + train_dataset, validation_dataset, _, _ = prepare_inputs_and_outputs( self, train_columns=train_columns, validation_columns=validation_columns, @@ -595,6 +603,7 @@ def __getattr__(self, name): weight_decay=weight_decay, lr_scheduler_type=lr_scheduler_type, warmup_steps=warmup_steps, + eval_accumulation_steps=kwargs.pop("eval_accumulation_steps", 1), logging_strategy=kwargs.pop("logging_strategy", None) or "steps", logging_steps=kwargs.pop("logging_steps", 1), evaluation_strategy=kwargs.pop("evaluation_strategy", None) or "epoch", @@ -613,7 +622,7 @@ def __getattr__(self, name): ) # Setup trainer - trainer = _wrap_trainer_cls( + trainer = wrap_trainer_cls( trainer_cls=trainer_cls or Trainer, **trainer_override_kwargs )( train_dataset=train_dataset, @@ -621,7 +630,9 @@ def __getattr__(self, name): model=wrapped_model_with_loss, tokenizer=self.tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics, + compute_metrics=wrap_compute_metrics( + compute_metrics=compute_metrics, training_args=training_args + ), callbacks=callbacks, preprocess_logits_for_metrics=preprocess_logits_for_metrics, args=training_args, @@ -629,7 +640,7 @@ def __getattr__(self, name): trainer.remove_callback(PrinterCallback) # Start the trainer - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) # Save the model to disk self._save_model( @@ -855,6 +866,9 @@ def _load_model( # model = torch.compile(model) pass + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial(lambda self: self.logger.info("Finished loading."), self) diff --git a/src/trainers/train_setfit_classifier.py b/src/trainers/train_setfit_classifier.py index 4191802d..7812fc32 100644 --- a/src/trainers/train_setfit_classifier.py +++ b/src/trainers/train_setfit_classifier.py @@ -15,14 +15,18 @@ from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn from ..utils.arg_utils import AUTO, DEFAULT, Default, default_to from ..utils.background_utils import RunIfTimeout -from ..utils.hf_model_utils import get_base_model_from_peft_model, validate_peft_config -from ..utils.import_utils import ignore_transformers_warnings -from ._train_hf_base import ( - _prepare_inputs_and_outputs, - _start_hf_trainer, - _TrainHFBase, +from ..utils.hf_model_utils import ( + filter_model_warnings, + get_base_model_from_peft_model, + validate_peft_config, +) +from ..utils.hf_training_utils import ( get_logging_callback, + prepare_inputs_and_outputs, + start_hf_trainer, ) +from ..utils.import_utils import ignore_transformers_warnings +from ._train_hf_base import _TrainHFBase from ._vendored._setfit_helper import get_peft_model_cls # type:ignore[attr-defined] from .train_hf_classifier import TrainHFClassifier from .train_sentence_transformer import TrainSentenceTransformer @@ -198,12 +202,17 @@ def _create_model( from peft import prepare_model_for_kbit_training if self.quantization_config: # pragma: no cover - model.model_body = prepare_model_for_kbit_training(model.model_body) + model.model_body = prepare_model_for_kbit_training( + model.model_body, use_gradient_checkpointing=True + ) model.model_body = get_peft_model_cls()( model=model.model_body, peft_config=validate_peft_config(model.model_body, self.peft_config), ) + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial(lambda self: self.logger.info("Finished loading."), self) @@ -249,7 +258,7 @@ def _train( # type:ignore[override] validation_dataset, label2id, is_multi_target, - ) = _prepare_inputs_and_outputs( + ) = prepare_inputs_and_outputs( self, train_columns={ ("text", "Train Input"): train_input, @@ -409,7 +418,7 @@ def evaluate(self, *args, **kwargs): # Start the trainer self.logger.info("Training SetFit model body (embeddings)...") - _start_hf_trainer(self, trainer) + start_hf_trainer(self, trainer) self.logger.info("Running final trained SetFit model evaluation...") trainer.evaluate(final=True) # Run a final evaluation @@ -579,6 +588,9 @@ def _load_model( # model = torch.compile(model) pass + # Filter any warnings from the model + filter_model_warnings() + # Finished loading log_if_timeout.stop( partial(lambda self: self.logger.info("Finished loading."), self) diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index 2932ef32..76cbb37d 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from functools import cache, cached_property, partial, total_ordering +from functools import cached_property, partial, total_ordering from logging import Logger from typing import Any @@ -17,10 +17,7 @@ from ..steps.data_card import DataCardType, sort_data_card from ..utils.distributed_utils import run_distributed from ..utils.fs_utils import clear_dir, mkdir, move_dir, safe_fn -from ..utils.import_utils import ignore_training_warnings, ignore_transformers_warnings - -with ignore_transformers_warnings(): - from transformers import TrainerState +from ..utils.import_utils import ignore_training_warnings class ModelNoLongerExistsError(Exception): @@ -60,24 +57,6 @@ def __str_(self) -> str: # pragma: no cover return self.__repr__() -_old_TrainerState__post_init__ = TrainerState.__post_init__ - - -def _deserialize_join_metric__TrainerState__post_init__(self, *args, **kwargs): - _old_TrainerState__post_init__(self, *args, **kwargs) - if ( - hasattr(self, "best_metric") - and isinstance(self.best_metric, dict) - and "is_joint_metric" in self.best_metric - ): - self.best_metric = JointMetric(**self.best_metric) - - -@cache -def _monkey_patch_TrainerState__post_init__(): - TrainerState.__post_init__ = _deserialize_join_metric__TrainerState__post_init__ - - class Trainer(ABC): _trainer_tags = ["datadreamer"] diff --git a/src/utils/device_utils.py b/src/utils/device_utils.py index 6afad75f..f8188e63 100644 --- a/src/utils/device_utils.py +++ b/src/utils/device_utils.py @@ -89,6 +89,7 @@ def get_device_env_variables(devices: list[int | str | torch.device]) -> dict[st len(true_device_ids) == len(devices) ), f"The device list you specified ({devices}) is invalid (or devices could not be found)." device_env = {"CUDA_VISIBLE_DEVICES": ",".join(map(str, true_device_ids))} + device_env["NCCL_P2P_DISABLE"] = "1" return device_env diff --git a/src/utils/distributed_utils.py b/src/utils/distributed_utils.py index 9ef9adb2..a2c0c5b6 100644 --- a/src/utils/distributed_utils.py +++ b/src/utils/distributed_utils.py @@ -29,7 +29,9 @@ from transformers import PreTrainedModel -def default_fsdp_config(model: PreTrainedModel) -> dict[str, Any]: # pragma: no cover +def default_fsdp_config( + model: PreTrainedModel, kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: # pragma: no cover if is_peft_model(model): model = get_base_model_from_peft_model(model) if isinstance(model, SentenceTransformer): @@ -47,13 +49,24 @@ def default_fsdp_config(model: PreTrainedModel) -> dict[str, Any]: # pragma: no transformer_layer_cls_to_wrap = list( filter(lambda x: x in named_modules, transformer_layer_cls_to_wrap) ) - return { - "fsdp": "full_shard auto_wrap", - "fsdp_config": { - "backward_prefetch": "backward_pre", - "transformer_layer_cls_to_wrap": transformer_layer_cls_to_wrap, + optional_fsdp_config = {} + if kwargs.get("gradient_checkpointing", False): + del kwargs["gradient_checkpointing"] + optional_fsdp_config["activation_checkpointing"] = "true" + os.environ["FSDP_OFFLOAD_PARAMS"] = "1" + return ( + { + "fsdp": "full_shard auto_wrap", + "fsdp_config": { + "backward_prefetch": "backward_pre", + "transformer_layer_cls_to_wrap": transformer_layer_cls_to_wrap, + "cpu_ram_efficient_loading": "true", + "sync_module_states": "true", + **optional_fsdp_config, + }, }, - } + kwargs, + ) def apply_distributed_config(self, kwargs: dict[str, Any]) -> dict[str, Any]: @@ -61,10 +74,10 @@ def apply_distributed_config(self, kwargs: dict[str, Any]) -> dict[str, Any]: _device = kwargs.pop("_device") self._selected_device = _device _model = kwargs.pop("_model") - default_fsdp_kwargs = ( - default_fsdp_config(_model) + default_fsdp_kwargs, kwargs = ( + default_fsdp_config(model=_model, kwargs=kwargs) if isinstance(_device, list) - else {"fsdp": "", "fsdp_config": None} + else ({"fsdp": "", "fsdp_config": None}, kwargs) ) fsdp = default_to(kwargs.pop("fsdp"), default_fsdp_kwargs["fsdp"]) fsdp_is_enabled = ( @@ -225,7 +238,7 @@ def configure_and_launch(): # Create a communication pipe spawn_context = get_context(method="spawn") - pipe: Any = spawn_context.Queue(2) + pipe: Any = spawn_context.Queue(nproc_per_node * 100) # Launch the spawned child processes (share the parent context with them) if final_logger.level > logging.DEBUG: diff --git a/src/utils/hf_model_utils.py b/src/utils/hf_model_utils.py index 0a89f3d0..95e69bbc 100644 --- a/src/utils/hf_model_utils.py +++ b/src/utils/hf_model_utils.py @@ -1,3 +1,5 @@ +import logging +import sys from copy import copy from functools import cache from typing import Any @@ -316,3 +318,23 @@ def get_model_optional_kwargs( if quantization_config is not None: # pragma: no cover optional_kwargs["quantization_config"] = quantization_config return optional_kwargs + + +def filter_model_warnings(): + # Filter warning logs + for model_logger_name in [ + n + for n in sys.modules.keys() + if ( + n.startswith("transformers.models.") and "modeling" in n and "auto" not in n + ) + ]: + model_logger = logging.getLogger(model_logger_name) + + class NoUseCacheIsIncompatibleWarningFilter(logging.Filter): # pragma: no cover + def filter(self, record): + return not record.getMessage().startswith( + "`use_cache=True` is incompatible with gradient checkpointing" + ) + + model_logger.addFilter(NoUseCacheIsIncompatibleWarningFilter()) diff --git a/src/utils/hf_training_utils.py b/src/utils/hf_training_utils.py new file mode 100644 index 00000000..0cc3b3f1 --- /dev/null +++ b/src/utils/hf_training_utils.py @@ -0,0 +1,973 @@ +import logging +import os +import sys +from functools import cache, partial +from itertools import chain +from typing import TYPE_CHECKING, Any, Callable, Type, cast + +import dill +import numpy as np +import torch +from datasets import Dataset, IterableDataset, Value, concatenate_datasets +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from .. import DataDreamer +from ..datasets import ( + OutputDatasetColumn, + OutputIterableDataset, + OutputIterableDatasetColumn, +) +from ..datasets.datasets import _SizedIterableDataset, get_sized_dataset +from ..steps import DataSource +from ..trainers.trainer import JointMetric +from ..utils.device_utils import ( + _TrainingArgumentDeviceOverrideMixin, + get_device_memory_monitoring_callback, +) +from ..utils.distributed_utils import ( + get_global_rank, + get_local_world_size, + is_distributed, + not_distributed_or_main_process, +) +from ..utils.import_utils import ignore_transformers_warnings + +with ignore_transformers_warnings(): + from setfit import logging as setfit_logging + from transformers import ( + PreTrainedTokenizer, + TrainerCallback, + TrainerState, + logging as hf_transformers_logging, + ) + +from transformers import ( + Seq2SeqTrainingArguments as _Seq2SeqTrainingArguments, + TrainingArguments as _TrainingArguments, +) + +if TYPE_CHECKING: # pragma: no cover + from ..trainers.train_hf_classifier import _TrainHFBase + + with ignore_transformers_warnings(): + from transformers import Trainer + + +_old_TrainerState__post_init__ = TrainerState.__post_init__ + + +def _deserialize_join_metric__TrainerState__post_init__(self, *args, **kwargs): + _old_TrainerState__post_init__(self, *args, **kwargs) + if ( + hasattr(self, "best_metric") + and isinstance(self.best_metric, dict) + and "is_joint_metric" in self.best_metric + ): + self.best_metric = JointMetric(**self.best_metric) + + +@cache +def _monkey_patch_TrainerState__post_init__(): + TrainerState.__post_init__ = _deserialize_join_metric__TrainerState__post_init__ + + +class TrainingArguments(_TrainingArgumentDeviceOverrideMixin, _TrainingArguments): + pass + + +class Seq2SeqTrainingArguments( + _TrainingArgumentDeviceOverrideMixin, _Seq2SeqTrainingArguments +): + pass + + +def wrap_trainer_cls( + trainer_cls: Type["Trainer"], + optimizers: tuple[None | Optimizer, None | LambdaLR] = (None, None), + optimizer: None | Optimizer = None, + lr_scheduler: None | LambdaLR = None, + compute_loss: None | Callable = None, +) -> Type["Trainer"]: + class WrappedTrainer(trainer_cls): + def create_optimizer(self): + final_optimizer = optimizer or optimizers[0] + if final_optimizer is not None: # pragma: no cover + self.optimizer = final_optimizer + else: + super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: None | Optimizer = None + ): + final_lr_scheduler = lr_scheduler or optimizers[1] + if final_lr_scheduler is not None: # pragma: no cover + self.lr_scheduler = final_lr_scheduler + else: + super().create_scheduler( + num_training_steps=num_training_steps, optimizer=optimizer + ) + + def compute_loss(self, model, inputs, return_outputs=False): + if compute_loss is not None: # pragma: no cover + return compute_loss(model, inputs, return_outputs=return_outputs) + else: + return super().compute_loss( + model, inputs, return_outputs=return_outputs + ) + + return WrappedTrainer + + +def prepare_inputs_and_outputs( # noqa: C901 + self: "_TrainHFBase", + train_columns: dict[ + tuple[str, str], OutputDatasetColumn | OutputIterableDatasetColumn + ], + validation_columns: dict[ + tuple[str, str], OutputDatasetColumn | OutputIterableDatasetColumn + ], + truncate: bool = False, + causal: bool = False, + dpo: bool = False, + reward_pairs: bool = False, + reward_scores: bool = False, +) -> tuple[ + Dataset | IterableDataset | _SizedIterableDataset, + Dataset | IterableDataset | _SizedIterableDataset, + dict[Any, int], + bool, +]: + num_proc = ( + ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + if sys.platform != "darwin" + else 1 + ) + label2id: dict[Any, int] = {} + is_multi_target: bool = False + + def get_train_column( + column_name: str, + ) -> OutputDatasetColumn | OutputIterableDatasetColumn: + for (train_column_name, _), train_column in train_columns.items(): + if train_column_name == column_name: + return train_column + raise KeyError(f"Train column {column_name} not found.") # pragma: no cover + + def get_validation_column( + column_name: str, + ) -> OutputDatasetColumn | OutputIterableDatasetColumn: + for ( + validation_column_name, + _, + ), validation_column in validation_columns.items(): + if validation_column_name == column_name: + return validation_column + raise KeyError( + f"Validation column {column_name} not found." + ) # pragma: no cover + + def apply_chat_prompt_template(prompt: str) -> str: + return ( + cast(str, self.chat_prompt_template) + .replace("{{system_prompt}}", self.system_prompt or "") + .replace("{{prompt}}", prompt) + ) + + def tokenize_function( + examples, + column_name: str, + new_column_name: str, + causal: bool, + reward_scores: bool, + ): # pragma: no cover + if reward_scores: + prompt, completion = examples[column_name] + if self.chat_prompt_template: + prompt = apply_chat_prompt_template(prompt) + input_ids = self.tokenizer( + prompt + completion, + truncation=truncate, + padding=False, + add_special_tokens=True, + )["input_ids"] + return { + "input_ids": input_ids[: self.tokenizer.model_max_length] + if truncate + else input_ids, + "labels": examples["label"], + } + elif causal: + prompt, completion = examples[column_name] + if self.chat_prompt_template: + prompt = apply_chat_prompt_template(prompt) + prompt_input_ids = self.tokenizer( + prompt, truncation=truncate, padding=False, add_special_tokens=True + )["input_ids"] + completion_input_ids = self.tokenizer( + completion, truncation=truncate, padding=False, add_special_tokens=False + )["input_ids"] + [self.tokenizer.eos_token_id] + prompt_labels = [-100] * len(prompt_input_ids) + input_ids = prompt_input_ids + completion_input_ids + labels = prompt_labels + completion_input_ids + return { + "input_ids": input_ids[: self.tokenizer.model_max_length] + if truncate + else input_ids, + "labels": labels[: self.tokenizer.model_max_length] + if truncate + else labels, + } + elif new_column_name in ["decoder_labels"]: + return { + "labels": self.tokenizer( + examples[column_name], + truncation=truncate, + padding=False, + add_special_tokens=True, + )["input_ids"] + } + else: + prompts = examples[column_name] + if self.chat_prompt_template: + prompts = list(map(apply_chat_prompt_template, prompts)) + tokenizer_results = self.tokenizer( + prompts, truncation=truncate, padding=False, add_special_tokens=True + ) + return { + new_column_name: tokenizer_results["input_ids"], + f"{new_column_name.replace('input_ids', '')}attention_mask": tokenizer_results[ + "attention_mask" + ], + } + + def tokenize_column_name( + column_name: str, + new_column_name: str, + causal: bool, + reward_scores: bool = False, + ) -> Callable: + return partial( + tokenize_function, + column_name=column_name, + new_column_name=new_column_name, + causal=causal, + reward_scores=reward_scores, + ) + + def tokenize_column( + column: OutputDatasetColumn | OutputIterableDatasetColumn, + new_column_name: str, + name: str, + causal: bool = False, + reward_scores: bool = False, + ) -> Dataset | IterableDataset: + column_name = column.column_names[0] + return column.step.map( + name=f"Tokenize {name}", + function=tokenize_column_name( + column_name, + new_column_name=new_column_name, + causal=causal, + reward_scores=reward_scores, + ), + batched=not causal and not reward_scores, + remove_columns=column.step.output.column_names, + total_num_rows=column.num_rows, + auto_progress=column.num_rows is not None, + lazy=isinstance(column, OutputIterableDatasetColumn), + progress_interval=sys.maxsize + if isinstance(column, OutputIterableDatasetColumn) + else 120, + save_num_proc=num_proc, + ).output.dataset + + def rename_column( + column: OutputDatasetColumn | OutputIterableDatasetColumn, new_column_name: str + ) -> Dataset | IterableDataset: + column_name = column.column_names[0] + column_dataset = column.step.output.dataset.select_columns(column.column_names) + return ( + column_dataset.rename_column(column_name, new_column_name) + if column_name != new_column_name + else column_dataset + ) + + def label_encode_function( + _, column_name: str, example: dict[str, Any] + ) -> dict[str, Any]: # pragma: no cover + if isinstance(example[column_name], list): + row_labels = set(str(label) for label in example[column_name]) + return { + column_name: [1 if label in row_labels else 0 for label in label2id] + } + else: + return {column_name: label2id[str(example[column_name])]} + + def label2id_column( + column: OutputDatasetColumn | OutputIterableDatasetColumn, + new_column_name: str, + name: str, + ) -> Dataset | IterableDataset: + column_name = column.column_names[0] + return rename_column( + column.step.map( + name=f"Encode {name} labels", + function=partial( + label_encode_function, sorted(label2id.keys()), column_name + ), + batched=False, + remove_columns=list( + set(column.step.output.column_names).difference(set([column_name])) + ), + total_num_rows=column.num_rows, + auto_progress=column.num_rows is not None, + lazy=isinstance(column, OutputIterableDatasetColumn), + progress_interval=sys.maxsize + if isinstance(column, OutputIterableDatasetColumn) + else 120, + save_num_proc=num_proc, + ).output[column_name], + new_column_name, + ) + + def process_column( + column: OutputDatasetColumn | OutputIterableDatasetColumn, + new_column_name: str, + name: str, + ) -> Dataset | IterableDataset: + if new_column_name == "label" and reward_scores is False: + return label2id_column( + column=column, new_column_name=new_column_name, name=name + ) + else: # pragma: no cover + return rename_column(column=column, new_column_name=new_column_name) + + def concatenate_prompts_and_completions( + dataset: Dataset | IterableDataset, + ) -> IterableDataset: + iterable_dataset = ( + dataset.to_iterable_dataset() if isinstance(dataset, Dataset) else dataset + ) + return iterable_dataset.map( + lambda row: {"text": [row["prompt"], row["completion"]]}, + remove_columns=["prompt", "completion"], + ) + + # Calculate label2id + uniq_labels = [] + for (new_column_name, name), column in list(train_columns.items()) + list( + validation_columns.items() + ): + column_name = column.column_names[0] + + def uniqify_labels(labels: set[Any], column_name, example): + nonlocal is_multi_target + if isinstance(example[column_name], list): + is_multi_target = True + is_new = False + for label in example[column_name]: + if label not in labels: + is_new = True + labels.add(label) + return is_new + else: + is_new = example[column_name] not in labels + labels.add(example[column_name]) + return is_new + + if new_column_name == "label" and reward_scores is False: + uniq_labels_column = column.step.filter( + name=f"Get all {name} label names", + function=partial(uniqify_labels, set(), column_name), + batched=False, + total_num_rows=column.num_rows, + auto_progress=column.num_rows is not None, + lazy=False, + progress_interval=sys.maxsize + if isinstance(column, OutputIterableDatasetColumn) + else 120, + ).output[column_name] + uniq_labels_from_column = list(uniq_labels_column) + uniq_labels += ( + list(chain.from_iterable(uniq_labels_column)) + if len(uniq_labels_from_column) > 0 + and isinstance(uniq_labels_from_column[0], list) + else uniq_labels_column + ) + uniq_labels = sorted(set(uniq_labels)) + for label in uniq_labels: + label2id[str(label)] = len(label2id) + + # Create train and validation datasets + train_dataset: Dataset | IterableDataset + validation_dataset: Dataset | IterableDataset + if reward_pairs: + # Check if scores are provided + try: + get_train_column("train_chosen_scores") + has_scores = True + except KeyError: + has_scores = False + + # Get data collator + def prepare_for_reward_pairs(row): # pragma: no cover + row = row.copy() + if self.chat_prompt_template: + row["prompt"] = apply_chat_prompt_template(row["prompt"]) + row["chosen"] = row["prompt"] + row["chosen"] + row["rejected"] = row["prompt"] + row["rejected"] + reward_results = {} + chosen_tokenizer_results = self.tokenizer( + row["chosen"], + truncation=truncate, + padding=False, + add_special_tokens=True, + ) + reward_results["input_ids_chosen"] = chosen_tokenizer_results["input_ids"] + rejected_tokenizer_results = self.tokenizer( + row["rejected"], + truncation=truncate, + padding=False, + add_special_tokens=True, + ) + reward_results["input_ids_rejected"] = rejected_tokenizer_results[ + "input_ids" + ] + if "chosen_scores" in row and "rejected_scores" in row: + reward_results["margin"] = row["chosen_scores"] - row["rejected_scores"] + return reward_results + + # Run data collator + train_columns_to_combine = [ + rename_column(get_train_column("train_prompts"), "prompt"), + rename_column(get_train_column("train_chosen"), "chosen"), + rename_column(get_train_column("train_rejected"), "rejected"), + ] + if has_scores: + train_columns_to_combine.extend( + [ + rename_column( + get_train_column("train_chosen_scores"), "chosen_scores" + ), + rename_column( + get_train_column("train_rejected_scores"), "rejected_scores" + ), + ] + ) + train_combine_step = DataSource( + "Combine Train Prompts, Chosen Generations, and Rejected Generations", + data=concatenate_datasets(train_columns_to_combine, axis=1), + total_num_rows=get_train_column("train_prompts").num_rows, + auto_progress=get_train_column("train_prompts").num_rows is not None, + ) + train_dataset = train_combine_step.map( + name="Prepare Train Dataset for Reward Model Training", + function=prepare_for_reward_pairs, + batched=False, + remove_columns=train_combine_step.output.column_names, + total_num_rows=get_train_column("train_prompts").num_rows, + auto_progress=get_train_column("train_prompts").num_rows is not None, + lazy=isinstance(train_combine_step.output, OutputIterableDataset), + progress_interval=sys.maxsize + if isinstance(train_combine_step.output, OutputIterableDataset) + else 120, + save_num_proc=num_proc, + ).output.dataset + validation_columns_to_combine = [ + rename_column(get_validation_column("validation_prompts"), "prompt"), + rename_column(get_validation_column("validation_chosen"), "chosen"), + rename_column(get_validation_column("validation_rejected"), "rejected"), + ] + if has_scores: + validation_columns_to_combine.extend( + [ + rename_column( + get_validation_column("validation_chosen_scores"), + "chosen_scores", + ), + rename_column( + get_validation_column("validation_rejected_scores"), + "rejected_scores", + ), + ] + ) + validation_combine_step = DataSource( + "Combine Validation Prompts, Chosen Generations, and Rejected Generations", + data=concatenate_datasets(validation_columns_to_combine, axis=1), + total_num_rows=get_validation_column("validation_prompts").num_rows, + auto_progress=get_validation_column("validation_prompts").num_rows + is not None, + ) + validation_dataset = validation_combine_step.map( + name="Prepare Validation Dataset for Reward Model Training", + function=prepare_for_reward_pairs, + batched=False, + remove_columns=validation_combine_step.output.column_names, + total_num_rows=get_validation_column("validation_prompts").num_rows, + auto_progress=get_validation_column("validation_prompts").num_rows + is not None, + lazy=isinstance(validation_combine_step.output, OutputIterableDataset), + progress_interval=sys.maxsize + if isinstance(validation_combine_step.output, OutputIterableDataset) + else 120, + save_num_proc=num_proc, + ).output.dataset + elif dpo: + if TYPE_CHECKING: # pragma: no cover + DPODataCollatorWithPadding: Any = None + else: + from ..trainers._vendored._dpo_helper import DPODataCollatorWithPadding + + # Get data collator + data_collator = DPODataCollatorWithPadding( + tokenizer=self.tokenizer, + max_length=self.tokenizer.model_max_length if truncate else sys.maxsize, + max_prompt_length=self.tokenizer.model_max_length + if truncate + else sys.maxsize, + label_pad_token_id=-100, + padding_value=0, + truncation_mode="keep_end", + is_encoder_decoder=self._is_encoder_decoder, + max_target_length=self.tokenizer.model_max_length + if truncate + else sys.maxsize, + ) + + def run_data_collator(row): # pragma: no cover + if self.chat_prompt_template: + row["prompt"] = apply_chat_prompt_template(row["prompt"]) + dpo_results = data_collator.__call__([row]) + for key, value in list(dpo_results.items()): + if "attention_mask" in key: + del dpo_results[key] + elif isinstance(value, list) and len(value) == 1: + dpo_results[key] = value[0] + elif isinstance(value, torch.Tensor) and len(value.shape) == 2: + value = value[0] + if truncate: + dpo_results[key] = value[: self.tokenizer.model_max_length] + return dpo_results + + # Run data collator + train_combine_step = DataSource( + "Combine Train Prompts, Chosen Generations, and Rejected Generations", + data=concatenate_datasets( + [ + rename_column(get_train_column("train_prompts"), "prompt"), + rename_column(get_train_column("train_chosen"), "chosen"), + rename_column(get_train_column("train_rejected"), "rejected"), + ], + axis=1, + ), + total_num_rows=get_train_column("train_prompts").num_rows, + auto_progress=get_train_column("train_prompts").num_rows is not None, + ) + train_dataset = train_combine_step.map( + name="Prepare Train Dataset for DPO", + function=run_data_collator, + batched=False, + total_num_rows=get_train_column("train_prompts").num_rows, + auto_progress=get_train_column("train_prompts").num_rows is not None, + lazy=isinstance(train_combine_step.output, OutputIterableDataset), + progress_interval=sys.maxsize + if isinstance(train_combine_step.output, OutputIterableDataset) + else 120, + save_num_proc=num_proc, + ).output.dataset + validation_combine_step = DataSource( + "Combine Validation Prompts, Chosen Generations, and Rejected Generations", + data=concatenate_datasets( + [ + rename_column( + get_validation_column("validation_prompts"), "prompt" + ), + rename_column(get_validation_column("validation_chosen"), "chosen"), + rename_column( + get_validation_column("validation_rejected"), "rejected" + ), + ], + axis=1, + ), + total_num_rows=get_validation_column("validation_prompts").num_rows, + auto_progress=get_validation_column("validation_prompts").num_rows + is not None, + ) + validation_dataset = validation_combine_step.map( + name="Prepare Validation Dataset for DPO", + function=run_data_collator, + batched=False, + total_num_rows=get_validation_column("validation_prompts").num_rows, + auto_progress=get_validation_column("validation_prompts").num_rows + is not None, + lazy=isinstance(validation_combine_step.output, OutputIterableDataset), + progress_interval=sys.maxsize + if isinstance(validation_combine_step.output, OutputIterableDataset) + else 120, + save_num_proc=num_proc, + ).output.dataset + elif reward_scores: + train_combined = concatenate_datasets( + [ + rename_column(get_train_column("train_input"), "prompt"), + rename_column(get_train_column("train_output"), "completion"), + rename_column(get_train_column("label"), "label").cast_column( + "label", Value("float64") + ), + ], + axis=1, + ) + train_dataset = tokenize_column( + DataSource( + "Concatenate Train Prompts and Generations", + data=concatenate_prompts_and_completions(train_combined), + total_num_rows=get_train_column("train_input").num_rows, + auto_progress=get_train_column("train_input").num_rows is not None, + save=not isinstance(train_combined, IterableDataset), + ).output["text"], + "input_ids", + "Train Dataset", + reward_scores=True, + ) + validation_combined = concatenate_datasets( + [ + rename_column(get_validation_column("validation_input"), "prompt"), + rename_column(get_validation_column("validation_output"), "completion"), + rename_column(get_validation_column("label"), "label").cast_column( + "label", Value("float64") + ), + ], + axis=1, + ) + validation_dataset = tokenize_column( + DataSource( + "Concatenate Validation Prompts and Generations", + data=concatenate_prompts_and_completions(validation_combined), + total_num_rows=get_validation_column("validation_input").num_rows, + auto_progress=get_validation_column("validation_input").num_rows + is not None, + save=not isinstance(validation_combined, IterableDataset), + ).output["text"], + "input_ids", + "Validation Dataset", + reward_scores=True, + ) + elif causal: + train_combined = concatenate_datasets( + [ + rename_column(get_train_column("train_input"), "prompt"), + rename_column(get_train_column("train_output"), "completion"), + ], + axis=1, + ) + train_dataset = tokenize_column( + DataSource( + "Concatenate Train Input and Output", + data=concatenate_prompts_and_completions(train_combined), + total_num_rows=get_train_column("train_input").num_rows, + auto_progress=get_train_column("train_input").num_rows is not None, + save=not isinstance(train_combined, IterableDataset), + ).output["text"], + "input_ids", + "Train Dataset", + causal=True, + ) + validation_combined = concatenate_datasets( + [ + rename_column(get_validation_column("validation_input"), "prompt"), + rename_column(get_validation_column("validation_output"), "completion"), + ], + axis=1, + ) + validation_dataset = tokenize_column( + DataSource( + "Concatenate Validation Input and Output", + data=concatenate_prompts_and_completions(validation_combined), + total_num_rows=get_validation_column("validation_input").num_rows, + auto_progress=get_validation_column("validation_input").num_rows + is not None, + save=not isinstance(validation_combined, IterableDataset), + ).output["text"], + "input_ids", + "Validation Dataset", + causal=True, + ) + else: + train_dataset = concatenate_datasets( + [ + tokenize_column(train_column, train_column_name, name) + if train_column_name in ["input_ids", "decoder_labels"] + or train_column_name.endswith("_input_ids") + else process_column(train_column, train_column_name, name) + for (train_column_name, name), train_column in train_columns.items() + ], + axis=1, + ) + validation_dataset = concatenate_datasets( + [ + tokenize_column(validation_column, validation_column_name, name) + if validation_column_name in ["input_ids", "decoder_labels"] + or validation_column_name.endswith("_input_ids") + else process_column(validation_column, validation_column_name, name) + for ( + validation_column_name, + name, + ), validation_column in validation_columns.items() + ], + axis=1, + ) + + # Save information for publishing + train_step = list(train_columns.values())[0].step + self._step_metadata = train_step._get_metadata(train_step.output) + + # Save information for publishing + self._examples = { + name: ( + train_column.dataset[:3][train_column.column_names[0]] + if isinstance(train_column.dataset, Dataset) + else list( + map( + lambda row: row[train_column.column_names[0]], + train_column.dataset.take(3), + ) + ) + ) + for (_, name), train_column in train_columns.items() + } + if reward_scores: + if self.chat_prompt_template: + self._examples["Train Prompts"] = [ + apply_chat_prompt_template(prompt) + for prompt in self._examples["Train Prompts"] + ] + self._examples["Train Input"] = [ + prompt + generation + for prompt, generation in zip( + self._examples["Train Prompts"], self._examples["Train Generations"] + ) + ] + elif reward_pairs: + if self.chat_prompt_template: + self._examples["Train Prompts"] = [ + apply_chat_prompt_template(prompt) + for prompt in self._examples["Train Prompts"] + ] + chosen_examples = [ + prompt + generation + for prompt, generation in zip( + self._examples["Train Prompts"], + self._examples["Train Chosen Generations"], + ) + ] + rejected_examples = [ + prompt + generation + for prompt, generation in zip( + self._examples["Train Prompts"], + self._examples["Train Rejected Generations"], + ) + ] + self._examples["Train Input"] = list( + chain.from_iterable(zip(chosen_examples, rejected_examples)) + ) + elif dpo: + self._examples["Train Input"] = self._examples["Train Prompts"] + + # Return datasets + return ( + get_sized_dataset( + dataset=train_dataset, + total_num_rows=list(train_columns.values())[0].num_rows, + ), + get_sized_dataset( + dataset=validation_dataset, + total_num_rows=list(validation_columns.values())[0].num_rows, + ), + label2id, + is_multi_target, + ) + + +def start_hf_trainer(self: "_TrainHFBase", trainer: Any): # noqa: C901 + # Setup loggers the way we need them to be + if not DataDreamer.ctx.hf_log: + if self.logger.level <= logging.NOTSET: # pragma: no cover + hf_transformers_trainer_logger = logging.getLogger("transformers.trainer") + if ( + not hf_transformers_trainer_logger.level + or hf_transformers_trainer_logger.level > logging.INFO + ): + hf_transformers_trainer_logger.level = logging.INFO + hf_transformers_trainer_logger.propagate = True + DataDreamer._enable_hf_transformers_logging(progress_bars=False) + DataDreamer._enable_setfit_logging(progress_bars=False) + hf_transformers_logging.set_verbosity_info() + setfit_logging.set_verbosity_info() + + # Add GPU monitoring if distributed + device_memory_monitoring_callback = get_device_memory_monitoring_callback( + trainer=self + ) + trainer.add_callback(device_memory_monitoring_callback) + + # Run training + try: + # Try to resume + if self.resumable: + trainer.train(resume_from_checkpoint=True) + else: + raise ValueError() + except ValueError: + try: + # Nothing to resume from, so start a new training run + + # Evaluate before starting training so we can see how the model + # performs before any weight updates + if device_memory_monitoring_callback: + device_memory_monitoring_callback()._log_device_memory_usage() + if is_distributed() and trainer.is_fsdp_enabled: # pragma: no cover + from transformers.trainer import logger as trainer_logger + + # This is a hack to run .evaluate() before training happens on FSDP + # but after the FSDP is set up + old_info = trainer_logger.info + + def _info(old_info, *args, **kwargs): + if len(args) > 0 and args[0].startswith( + "***** Running training *****" + ): + trainer.evaluate() + trainer.model.train() # Switch the model back to train mode + trainer_logger.info = old_info # Undo the monkey-patch + return old_info(*args, **kwargs) + + trainer_logger.info = partial(_info, old_info) + else: + trainer.evaluate() + + # Start training + trainer.train() + except Exception as e: + raise e from None + if not DataDreamer.ctx.hf_log: + if self.logger.level <= logging.NOTSET: # pragma: no cover + logging.getLogger( + "transformers.trainer" + ).level = DataDreamer.ctx._transformers_trainer_verbosity + DataDreamer._disable_hf_transformers_logging() + DataDreamer._disable_setfit_logging() + + +class CustomDataCollatorWithPadding: + def __init__( + self, + tokenizer: PreTrainedTokenizer, + fields_to_pad: list[dict[str, Any]], + fields_to_keep: None | list[str] = None, + extra_column_names_to_add: None | dict[str, Any] = None, + ): + self.tokenizer = tokenizer + self.fields_to_pad = fields_to_pad + self.fields_to_keep = fields_to_keep + self.extra_column_names_to_add = extra_column_names_to_add + + def update_pad_token_id( + self, tensor: torch.Tensor, pad_token_id: int, keep_first_pad_token: bool + ): + # Find where the pad tokens are + pad_token_mask = tensor == self.tokenizer.pad_token_id + if keep_first_pad_token: + # Find the indices of the left-most pad token in each row + leftmost_true_indices = pad_token_mask.to(torch.int32).argmax(dim=1) + # Create a mask to help keep the left-most pad_token value + keep_leftmost_mask = ( + torch.arange(pad_token_mask.size(1)) <= leftmost_true_indices[:, None] + ) + # Apply the mask to the original mask + pad_token_mask = pad_token_mask & ~keep_leftmost_mask + # Update the pad token IDs + tensor[pad_token_mask] = pad_token_id + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + result = {} + for field in self.fields_to_pad: + tokenizer = field.get("tokenizer", self.tokenizer) + pad_results = tokenizer.pad( + [{"input_ids": feature[field["name"]]} for feature in features], + padding=True, + return_tensors="pt", + ) + result[field["output_name"]] = pad_results["input_ids"] + if "pad_token_id" in field: + self.update_pad_token_id( + tensor=result[field["output_name"]], + pad_token_id=field["pad_token_id"], + keep_first_pad_token=field.get("keep_first_pad_token", False), + ) + if "output_attention_mask_name" in field: # pragma: no cover + result[field["output_attention_mask_name"]] = pad_results[ + "attention_mask" + ] + if isinstance(self.extra_column_names_to_add, dict): + for ( + column_name, + default_value, + ) in self.extra_column_names_to_add.items(): + result[column_name] = default_value + if self.fields_to_keep is not None: + for field_name in self.fields_to_keep: + result[field_name] = [ + feature[field_name] for feature in features if field_name in feature + ] + if len(result[field_name]) > 0 and isinstance( + result[field_name][0], (bool, int, float, np.ndarray, torch.Tensor) + ): + result[field_name] = torch.tensor(result[field_name]) + elif len(result[field_name]) == 0: + del result[field_name] + return result + + +def get_logging_callback(trainer: "_TrainHFBase", log_loss: bool = True) -> Type: + class LoggingCallback(TrainerCallback): + def on_log(self_, args, state, control, logs=None, **kwargs): + if is_distributed() and get_global_rank() != 0: # pragma: no cover + return + logs = logs.copy() + if "eval_progress" in logs and logs["eval_progress"] == "100%": + return + _ = logs.pop("total_flos", None) + _ = logs.pop("eval_joint_metric", None) + if state.is_local_process_zero: + epoch = logs.pop("epoch", 0.0) + if any([metric.startswith("eval_") for metric in logs.keys()]): + logs = {k.replace("eval_", ""): v for k, v in logs.items()} + if not log_loss: + logs.pop("loss") + trainer.logger.info(f"Eval Epoch: {epoch} -- {logs}") + else: + logs = {k.replace("train_", ""): v for k, v in logs.items()} + if not log_loss: + logs.pop("loss") + trainer.logger.info(f"Train Epoch: {epoch} -- {logs}") + + return LoggingCallback + + +def wrap_compute_metrics(compute_metrics, training_args: "TrainingArguments"): + def _wrapped_compute_metrics(*args, **kwargs): + if not_distributed_or_main_process(): + computed_metrics = compute_metrics(*args, **kwargs) + if is_distributed(): # pragma: no cover + for _ in range(get_local_world_size() - 1): + DataDreamer.ctx.distributed_pipe.put(dill.dumps(computed_metrics)) + return computed_metrics + else: # pragma: no cover + return dill.loads(DataDreamer.ctx.distributed_pipe.get()) + + return _wrapped_compute_metrics if compute_metrics is not None else None diff --git a/src/utils/import_utils.py b/src/utils/import_utils.py index de20fd43..abf7363c 100644 --- a/src/utils/import_utils.py +++ b/src/utils/import_utils.py @@ -1,6 +1,5 @@ import contextlib import importlib -import logging import warnings from types import ModuleType @@ -68,24 +67,6 @@ def ignore_training_warnings(): category=UserWarning, message="Merge.*may get different generations due to rounding error.*", ) - - # Filter warning logs - for model_logger_name in [ - n - for n in logging.Logger.manager.loggerDict.keys() - if n.startswith("transformers.models.") - ]: - model_logger = logging.getLogger(model_logger_name) - - class NoUseCacheIsIncompatibleWarningFilter( - logging.Filter - ): # pragma: no cover - def filter(self, record): - return not record.getMessage().startswith( - "`use_cache=True` is incompatible with gradient checkpointing" - ) - - model_logger.addFilter(NoUseCacheIsIncompatibleWarningFilter()) yield None