From 2200bf7a45782d42bda73a042606a1abe454a62a Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 8 Jun 2023 15:38:30 +0200 Subject: [PATCH] [`Trainer`] Correct behavior of `_load_best_model` for PEFT models (#24103) * v1 * some refactor - add ST format as well * fix * add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME` --- src/transformers/trainer.py | 28 ++++++++++++++++++++-------- src/transformers/utils/__init__.py | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3231558dec18..d93e6b587de0 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -134,6 +134,8 @@ ) from .training_args import OptimizerNames, ParallelMode, TrainingArguments from .utils import ( + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -2177,11 +2179,20 @@ def _load_best_model(self): logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) + best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) + best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path): + if ( + os.path.exists(best_model_path) + or os.path.exists(best_safe_model_path) + or os.path.exists(best_adapter_model_path) + or os.path.exists(best_safe_adapter_model_path) + ): if self.is_deepspeed_enabled: deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) else: + has_been_loaded = True if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): # If the 'user_content.pt' file exists, load with the new smp api. @@ -2207,10 +2218,10 @@ def _load_best_model(self): self.accelerator, model, self.state.best_model_checkpoint ) else: - if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False): - # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly. + if is_peft_available() and isinstance(model, PeftModel): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): - if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")): + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) # Load_adapter has no return value present, modify it when appropriate. from torch.nn.modules.module import _IncompatibleKeys @@ -2219,12 +2230,13 @@ def _load_best_model(self): else: logger.warning( "The intermediate checkpoints of PEFT may not be saved correctly, " - "using `TrainerCallback` to save adapter_model.bin in corresponding folders, " + f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, " "here are some examples https://github.com/huggingface/peft/issues/96" ) + has_been_loaded = False else: - # We can't do pure 8bit training using transformers. - logger.warning("Could not loading a quantized checkpoint.") + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + has_been_loaded = False else: # We load the model state dict on the CPU to avoid an OOM error. if self.args.save_safetensors and os.path.isfile(best_safe_model_path): @@ -2236,7 +2248,7 @@ def _load_best_model(self): # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs load_result = model.load_state_dict(state_dict, False) - if not is_sagemaker_mp_enabled(): + if not is_sagemaker_mp_enabled() and has_been_loaded: self._issue_warnings_after_load(load_result) elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): load_result = load_sharded_checkpoint( diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7169c7daf969..3aa1f8aeb926 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -177,6 +177,8 @@ WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" +ADAPTER_WEIGHTS_NAME = "adapter_model.bin" +ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" TF2_WEIGHTS_NAME = "tf_model.h5" TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" TF_WEIGHTS_NAME = "model.ckpt"