Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer] Correct behavior of _load_best_model for PEFT models #24103

Merged
merged 4 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
Expand Down Expand Up @@ -2177,11 +2179,20 @@ def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
if (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path)
):
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
else:
has_been_loaded = True
if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api.
Expand All @@ -2207,10 +2218,10 @@ def _load_best_model(self):
self.accelerator, model, self.state.best_model_checkpoint
)
else:
if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
# If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
if is_peft_available() and isinstance(model, PeftModel):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
# Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys
Expand All @@ -2219,12 +2230,13 @@ def _load_best_model(self):
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
"using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
"here are some examples https://github.com/huggingface/peft/issues/96"
)
has_been_loaded = False
else:
# We can't do pure 8bit training using transformers.
logger.warning("Could not loading a quantized checkpoint.")
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
has_been_loaded = False
else:
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
Expand All @@ -2236,7 +2248,7 @@ def _load_best_model(self):
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
if not is_sagemaker_mp_enabled():
if not is_sagemaker_mp_enabled() and has_been_loaded:
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
load_result = load_sharded_checkpoint(
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@

WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
Expand Down