[Trainer] Correct behavior of _load_best_model for PEFT models #24103
Changes from 1 commit
```diff
@@ -2177,11 +2177,18 @@ def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")

         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
+        if (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(adapter_model_path)
+        ):
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
             else:
+                has_been_loaded = True
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                         # If the 'user_content.pt' file exists, load with the new smp api.
```
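For context, the new `adapter_model_path` check above matters because a PEFT/LoRA run saves only adapter weights into each checkpoint folder (`adapter_model.bin`), not `pytorch_model.bin`, so the old existence check made `_load_best_model` silently skip reloading the best checkpoint. A minimal sketch of that setup; the base model, dataset, and hyperparameters are illustrative placeholders, not taken from this PR:

```python
# Illustrative reproduction of the scenario this PR targets: training a LoRA
# adapter with Trainer and load_best_model_at_end=True, so that at the end of
# training Trainer._load_best_model() must restore adapter_model.bin.
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "bert-base-uncased"  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Wrap the base model with LoRA: only the adapter weights are trainable, so each
# checkpoint folder contains adapter_model.bin rather than full model weights.
model = get_peft_model(base_model, LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16, lora_dropout=0.1))

dataset = load_dataset("glue", "sst2")


def tokenize(batch):
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)


tokenized = dataset.map(tokenize, batched=True)

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,  # triggers Trainer._load_best_model() after training
    metric_for_best_model="eval_loss",
    num_train_epochs=2,
    per_device_train_batch_size=16,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"].select(range(512)),
    eval_dataset=tokenized["validation"].select(range(128)),
)
trainer.train()  # the best adapter checkpoint should be reloaded at the end
```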
```diff
@@ -2207,10 +2214,10 @@ def _load_best_model(self):
                         self.accelerator, model, self.state.best_model_checkpoint
                     )
                 else:
-                    if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
-                        # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
+                    if is_peft_available() and isinstance(model, PeftModel):
+                        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                         if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                            if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")):
+                            if os.path.exists(adapter_model_path):
                                 model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
                                 # Load_adapter has no return value present, modify it when appropriate.
                                 from torch.nn.modules.module import _IncompatibleKeys
```
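The `load_adapter` call in this branch is the same call one could make by hand to pull adapter weights back out of a checkpoint folder. A small sketch, assuming `peft` is installed and the folder holds `adapter_config.json` plus `adapter_model.bin`; the base model name and checkpoint path are invented for the example:

```python
import os

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

checkpoint_dir = "out/checkpoint-500"  # stands in for self.state.best_model_checkpoint

# Rebuild the PEFT-wrapped model, then pull the saved adapter weights back in,
# mirroring model.load_adapter(best_model_checkpoint, model.active_adapter) above.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative base model
model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16))

if os.path.exists(os.path.join(checkpoint_dir, "adapter_model.bin")):
    # load_adapter loads the checkpointed weights into the existing (active) adapter slot.
    model.load_adapter(checkpoint_dir, model.active_adapter)
```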
```diff
@@ -2222,9 +2229,11 @@ def _load_best_model(self):
                                     "using `TrainerCallback` to save adapter_model.bin in corresponding folders, "
                                     "here are some examples https://github.com/huggingface/peft/issues/96"
                                 )
+                                has_been_loaded = False
                         else:
                             # We can't do pure 8bit training using transformers.
                             logger.warning("Could not loading a quantized checkpoint.")
+                            has_been_loaded = False
                     else:
                         # We load the model state dict on the CPU to avoid an OOM error.
                         if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
```

Review discussion (inline comment on this hunk):

- should this be removed now?
- I think this is needed so that it can be used in the block below for the check, otherwise it will throw an error similar to #24096
- Ah sorry, I see what you meant, yes will remove it
- proposed something in bf31c5e
```diff
@@ -2236,7 +2245,7 @@ def _load_best_model(self):
                             # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                             # which takes *args instead of **kwargs
                             load_result = model.load_state_dict(state_dict, False)
-                        if not is_sagemaker_mp_enabled():
+                        if not is_sagemaker_mp_enabled() and has_been_loaded:
                             self._issue_warnings_after_load(load_result)
                 elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
                     load_result = load_sharded_checkpoint(
```
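The `has_been_loaded` flag is what makes the amended guard above safe: in the branches where nothing could be loaded, `load_result` is never assigned, so calling `self._issue_warnings_after_load(load_result)` unconditionally would hit an unbound local variable (which seems to be the failure mode the #24096 reference in the thread above points at). A tiny self-contained illustration of the pattern, with every name invented for the example:

```python
# Toy illustration of why the guard matters: if the "load" branch never runs,
# the name load_result is unbound, and using it raises UnboundLocalError.
def load_best(adapter_found: bool, use_guard: bool) -> None:
    has_been_loaded = True
    if adapter_found:
        load_result = "ok"  # stands in for load_adapter(...) / load_state_dict(...)
    else:
        print("warning: adapter checkpoint not found")
        has_been_loaded = False

    if not use_guard or has_been_loaded:
        # With use_guard=True this line is skipped when nothing was loaded;
        # without the guard, load_result is referenced before assignment.
        print("issue warnings for:", load_result)


load_best(adapter_found=True, use_guard=True)    # fine
load_best(adapter_found=False, use_guard=True)   # guarded, no error
# load_best(adapter_found=False, use_guard=False)  # would raise UnboundLocalError
```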
Review discussion:

- it can also be safetensor ckpts, right?
- Maybe adding `best_safe_adapter_model_path` should serve the purpose?
- perfect, will refactor that a bit
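A rough sketch of what that refactor could look like; this is only one reading of the suggestion, not the code that was eventually merged. The `best_safe_adapter_model_path` name comes from the comment above, and the `adapter_model.safetensors` filename is assumed to follow PEFT's safetensors naming:

```python
import os
from typing import Optional

# Adapter weight filenames: "adapter_model.bin" appears in this diff, the
# safetensors counterpart is an assumption here.
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"


def find_adapter_weights(best_model_checkpoint: str) -> Optional[str]:
    """Return the adapter weights file inside the checkpoint folder, preferring safetensors."""
    best_adapter_model_path = os.path.join(best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
    best_safe_adapter_model_path = os.path.join(best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
    for candidate in (best_safe_adapter_model_path, best_adapter_model_path):
        if os.path.isfile(candidate):
            return candidate
    return None


# The existence check in _load_best_model could then treat the folder as a PEFT
# checkpoint whenever find_adapter_weights(...) returns a path.
```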