[Trainer] Correct behavior of _load_best_model for PEFT models #24103

Merged Jun 8, 2023 (4 commits)
Changes from 1 commit
src/transformers/trainer.py: 19 changes (14 additions, 5 deletions)
@@ -2177,11 +2177,18 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")
Contributor: it can also be safetensor ckpts, right?

Contributor: Maybe adding best_safe_adapter_model_path should serve the purpose?

Contributor Author: perfect, will refactor that a bit


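A minimal sketch of the kind of refactor discussed above, checking for both a torch (.bin) and a safetensors serialization of the adapter. The file names and helper functions here are illustrative assumptions, not the PR's final code:

# Illustrative sketch only; the adapter file names are assumptions, not the PR's final constants.
import os

def _adapter_paths(best_model_checkpoint: str):
    # torch (.bin) and safetensors serializations of the PEFT adapter
    adapter_model_path = os.path.join(best_model_checkpoint, "adapter_model.bin")
    best_safe_adapter_model_path = os.path.join(best_model_checkpoint, "adapter_model.safetensors")
    return adapter_model_path, best_safe_adapter_model_path

def adapter_checkpoint_exists(best_model_checkpoint: str) -> bool:
    # True if either serialization of the adapter is present in the checkpoint folder.
    return any(os.path.exists(p) for p in _adapter_paths(best_model_checkpoint))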
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
+        if (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(adapter_model_path)
+        ):
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
             else:
+                has_been_loaded = True
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                         # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2207,10 +2214,10 @@ def _load_best_model(self):
                         self.accelerator, model, self.state.best_model_checkpoint
                     )
                 else:
-                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
-                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
+                    if is_peft_available() and isinstance(model, PeftModel):
+                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                         if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
+                            if os.path.exists(adapter_model_path):
                                 model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                 # Load_adapter has no return value present, modify it when appropriate.
                                 from torch.nn.modules.module import _IncompatibleKeys
@@ -2222,9 +2229,11 @@ def _load_best_model(self):
                                     "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
                                     "here are some examples https://github.com/huggingface/peft/issues/96"
                                 )
+                                has_been_loaded = False
                         else:
                             # We can't do pure 8bit training using transformers.
                             logger.warning("Could not loading a quantized checkpoint.")
+                            has_been_loaded = False
Contributor: should this be removed now?

Contributor Author: I think this is needed so that it can be used in the block below for the check, otherwise it will throw an error similar to #24096

Contributor Author: Ah sorry, I see what you meant, yes, will remove it

Contributor Author: proposed something in bf31c5e

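For context, a standalone sketch (simplified, not the actual Trainer code) of why the has_been_loaded flag is kept: when the adapter or quantized checkpoint cannot be loaded, load_result is never assigned, so an unguarded call to self._issue_warnings_after_load(load_result) would raise an UnboundLocalError similar to the failure referenced in #24096.

# Simplified illustration; names and control flow are assumptions, not the Trainer code itself.
def load_best_model_sketch(checkpoint_loads: bool) -> None:
    has_been_loaded = True
    if checkpoint_loads:
        load_result = "dummy-load-result"  # stands in for the model.load_state_dict(...) return value
    else:
        # adapter_model.bin missing or quantized checkpoint unusable: nothing was loaded
        has_been_loaded = False
    if has_been_loaded:
        # guarded: load_result is guaranteed to exist on this path
        print("issuing warnings for", load_result)

load_best_model_sketch(checkpoint_loads=True)   # prints the warning line
load_best_model_sketch(checkpoint_loads=False)  # skips it instead of raising UnboundLocalError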
                     else:
                         # We load the model state dict on the CPU to avoid an OOM error.
                         if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
@@ -2236,7 +2245,7 @@ def _load_best_model(self):
                         # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                         # which takes *args instead of **kwargs
                         load_result = model.load_state_dict(state_dict, False)
-                    if not is_sagemaker_mp_enabled():
+                    if not is_sagemaker_mp_enabled() and has_been_loaded:
                         self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
             load_result = load_sharded_checkpoint(
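For readers unfamiliar with the scenario this PR fixes, here is a rough sketch of the setup in which _load_best_model has to reload a PEFT adapter at the end of training. The model name, LoRA settings, and toy dataset are placeholders chosen for illustration, not taken from the PR:

# Sketch of the affected scenario: a PeftModel trained with Trainer and load_best_model_at_end=True.
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

model_name = "prajjwal1/bert-tiny"  # small placeholder model for a quick demo
tokenizer = AutoTokenizer.from_pretrained(model_name)
base = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
peft_model = get_peft_model(base, LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16))

# Tiny toy dataset so the example runs end to end.
ds = Dataset.from_dict({"text": ["great movie", "terrible movie"] * 8, "label": [1, 0] * 8})
ds = ds.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=16))

args = TrainingArguments(
    output_dir="peft-trainer-out",
    num_train_epochs=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,   # triggers Trainer._load_best_model after training
    metric_for_best_model="eval_loss",
    report_to=[],
)

trainer = Trainer(model=peft_model, args=args, train_dataset=ds, eval_dataset=ds)
trainer.train()  # with this fix, the best checkpoint's adapter_model.bin is reloaded via load_adapter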