Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEFT] make the trainer support resume checkpoint from a named adapter #28531 #28547

Closed
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2038,9 +2038,16 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
if model is None:
model = self.model

named_adapter_subfolder = ''
if _is_peft_model(model):
# adapter with adapter_name will be saved in checkpoint/adapter_name subfolder, therefore join the path
# to the subfolder if necessary
named_adapter_subfolder = model.active_adapter if model.active_adapter not in ['default', None] else ''

config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
adapter_weights_file = os.path.join(resume_from_checkpoint, named_adapter_subfolder, ADAPTER_WEIGHTS_NAME)
adapter_safe_weights_file = os.path.join(resume_from_checkpoint, named_adapter_subfolder,
ADAPTER_SAFE_WEIGHTS_NAME)
weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
Expand Down Expand Up @@ -2129,8 +2136,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
elif _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(resume_from_checkpoint):
model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
adapter_model_path = os.path.join(resume_from_checkpoint, named_adapter_subfolder)
if os.path.exists(adapter_model_path):
model.load_adapter(adapter_model_path, model.active_adapter, is_trainable=True)
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
Expand Down