[PEFT] make the trainer support resume checkpoint from a named adapter #28531 #28547

Closed
17 changes: 13 additions & 4 deletions src/transformers/trainer.py
@@ -2149,9 +2149,17 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         if model is None:
             model = self.model
 
+        named_adapter_subfolder = ""
+        if _is_peft_model(model):
+            # an adapter with a custom adapter_name is saved in the checkpoint/adapter_name subfolder, so join
+            # the path to that subfolder if necessary
+            named_adapter_subfolder = model.active_adapter if model.active_adapter not in ["default", None] else ""
+
         config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
-        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
-        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+        adapter_weights_file = os.path.join(resume_from_checkpoint, named_adapter_subfolder, ADAPTER_WEIGHTS_NAME)
+        adapter_safe_weights_file = os.path.join(
+            resume_from_checkpoint, named_adapter_subfolder, ADAPTER_SAFE_WEIGHTS_NAME
+        )
         weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
         weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
@@ -2255,8 +2263,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         elif _is_peft_model(model):
             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
-                if os.path.exists(resume_from_checkpoint):
-                    model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
+                adapter_model_path = os.path.join(resume_from_checkpoint, named_adapter_subfolder)
+                if os.path.exists(adapter_model_path):
+                    model.load_adapter(adapter_model_path, model.active_adapter, is_trainable=True)
                 else:
                     logger.warning(
                         "The intermediate checkpoints of PEFT may not be saved correctly, "
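For context, a minimal sketch of the checkpoint layout this change accounts for and how the resume path is built. This is illustrative only, not part of the diff; the checkpoint directory and adapter name are hypothetical:

# Sketch only: why the subfolder join is needed when resuming a named adapter.
# PEFT saves a non-"default" adapter under <checkpoint>/<adapter_name>/, e.g.
#   output/checkpoint-5/my_adapter/adapter_model.safetensors
# whereas the "default" adapter is written directly into <checkpoint>/.
import os

resume_from_checkpoint = "output/checkpoint-5"  # hypothetical checkpoint directory
active_adapter = "my_adapter"                   # hypothetical named adapter

named_adapter_subfolder = active_adapter if active_adapter not in ["default", None] else ""
adapter_model_path = os.path.join(resume_from_checkpoint, named_adapter_subfolder)
print(adapter_model_path)  # "output/checkpoint-5/my_adapter"; just "output/checkpoint-5" for "default"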
113 changes: 113 additions & 0 deletions tests/trainer/test_trainer.py
@@ -502,6 +502,11 @@ def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
         for param_name, shard_file in zip(keys, shard_files):
             saver({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
 
+    def check_state_dict_are_the_same(self, state_dict_1, state_dict_2):
+        self.assertEquals(state_dict_1.keys(), state_dict_2.keys())
+        for key in state_dict_1.keys():
+            self.assertTrue(torch.equal(state_dict_1[key], state_dict_2[key]))
+
 
 @require_torch
 @require_sentencepiece
@@ -3451,3 +3456,111 @@ def test_hyperparameter_search_backends(self):
             list(ALL_HYPERPARAMETER_SEARCH_BACKENDS.keys()),
             list(HPSearchBackend),
         )
+
+
+@require_peft
+@require_torch
+class TrainerPeftTest(unittest.TestCase, TrainerIntegrationCommon):
+    def get_regression_adapter_trainer(
+        self,
+        keep_report_to=False,
+        adapter_name="default",
+        adapter_kwargs=None,
+        **kwargs,
+    ):
+        # from transformers import OPTForCausalLM
+        base_model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        pretrained_base_model = AutoModelForCausalLM.from_pretrained(base_model_id).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+
+        def tokenize_sample(sample):
+            return tokenizer(
+                sample["quote"],
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=tokenizer.model_max_length,
+            )
+
+        from datasets import load_dataset
+
+        data = load_dataset("Abirate/english_quotes")
+        data = (
+            data["train"]
+            .filter(lambda example, indice: indice < 9, with_indices=True)
+            .map(tokenize_sample, batched=True)
+        )
+
+        # generate peft
+        from peft import LoraConfig, get_peft_model
+
+        adapter_kwargs = (
+            adapter_kwargs
+            if adapter_kwargs is not None
+            else {
+                "r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "bias": "none",
+                "target_modules": ["q_proj"],
+                "task_type": "CAUSAL_LM",
+            }
+        )
+
+        lora_config = LoraConfig(**adapter_kwargs)
+
+        peft_model = get_peft_model(
+            model=pretrained_base_model,
+            peft_config=lora_config,
+            adapter_name=adapter_name,
+        )
+
+        compute_metrics = kwargs.pop("compute_metrics", None)
+        optimizers = kwargs.pop("optimizers", (None, None))
+        output_dir = kwargs.pop("output_dir", "./regression")
+        preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
Comment on lines +3518 to +3521

Collaborator:
It's best to keep the logic for tests as simple and explicit as possible. Here, rather than popping from kwargs, these should be keyword arguments with default values in get_regression_adapter_trainer.

Author:
Basically you are right. I imitated the coding style of other test cases in test_trainer.py, which mix some args into kwargs and pop them later. What do you think? Should I keep the coding style consistent with those tests, or separate them out as method arguments? Either way is not a big effort.

Collaborator:
I'd rather they were explicitly set, or, tbh because this method is only used in one place, just hardcoded
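A rough sketch of what that suggestion could look like (illustrative only, not committed code; the defaults mirror the kwargs.pop calls above):

# Hypothetical signature implementing the reviewer's suggestion: explicit keyword
# arguments with defaults instead of popping them from **kwargs inside the body.
def get_regression_adapter_trainer(
    self,
    keep_report_to=False,
    adapter_name="default",
    adapter_kwargs=None,
    compute_metrics=None,
    optimizers=(None, None),
    output_dir="./regression",
    preprocess_logits_for_metrics=None,
    **kwargs,
):
    ...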


+
+        args = RegressionTrainingArguments(output_dir, keep_report_to=keep_report_to, **kwargs)
+        return Trainer(
+            peft_model,
+            args,
+            data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            train_dataset=data,
+            compute_metrics=compute_metrics,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+    def test_can_resume_named_adapter_training(self):
+        """
+        test a named adapter can be resumed correctly
+        """
+        from peft import get_peft_model_state_dict
+
+        adapter_name = "test_adapter"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            kwargs = {
+                "output_dir": tmpdir,
+                "save_steps": 5,
+                "learning_rate": 0.1,
+                "logging_steps": 5,
+                "adapter_name": adapter_name,
+            }
+            trainer = self.get_regression_adapter_trainer(**kwargs)
+            trainer.train()
+
+            adapter_state_dict = get_peft_model_state_dict(model=trainer.model, adapter_name=adapter_name)
+            state = dataclasses.asdict(trainer.state)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            # Reinitialize trainer
+            trainer = self.get_regression_adapter_trainer(**kwargs)
+            # resume training from last checkpoint
+            trainer.train(resume_from_checkpoint=checkpoint)
+
+            adapter_state_dict_resumed = get_peft_model_state_dict(model=trainer.model, adapter_name=adapter_name)
+            state_resumed = dataclasses.asdict(trainer.state)
+            self.check_state_dict_are_the_same(adapter_state_dict, adapter_state_dict_resumed)
+            self.check_trainer_state_are_the_same(state, state_resumed)
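For reference, a minimal usage sketch of the workflow this PR enables: training a named (non-default) adapter with Trainer and resuming it from an intermediate checkpoint. The model id, hyperparameters, paths, and train_dataset are placeholders; train_dataset is assumed to be a tokenized dataset prepared roughly as in the test above:

# Sketch only: resume Trainer training of a PEFT model that uses a named adapter.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
lora = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj"], task_type="CAUSAL_LM")
model = get_peft_model(base, lora, adapter_name="my_adapter")  # named adapter, not "default"

args = TrainingArguments(output_dir="out", save_steps=5, max_steps=10)
Trainer(model=model, args=args, train_dataset=train_dataset).train()
# Checkpoints now contain out/checkpoint-5/my_adapter/ with the adapter weights.

# Later, e.g. in a fresh process: rebuild the same named adapter and resume.
# With this change, Trainer loads the adapter weights from the my_adapter subfolder.
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj"], task_type="CAUSAL_LM"), adapter_name="my_adapter")
Trainer(model=model, args=args, train_dataset=train_dataset).train(resume_from_checkpoint="out/checkpoint-5")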