Skip to content

Commit

Permalink
Set pretrained model name correctly, if provided, in HF Checkpointer (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 authored Jul 29, 2024
1 parent 6d5d016 commit 6f4aa8c
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ def __init__(
)

self.mlflow_logging_config = mlflow_logging_config
if 'metadata' in self.mlflow_logging_config:
self.pretrained_model_name = self.mlflow_logging_config[
'metadata'].get(
'pretrained_model_name',
None,
)
else:
self.pretrained_model_name = None

self.huggingface_folder_name_fstr = os.path.join(
'huggingface',
Expand Down Expand Up @@ -529,6 +537,16 @@ def tensor_hook(
original_tokenizer,
)

# Ensure that the pretrained model name is correctly set on the saved HF checkpoint.
if self.pretrained_model_name is not None:
new_model_instance.name_or_path = self.pretrained_model_name
if self.using_peft:
new_model_instance.base_model.name_or_path = self.pretrained_model_name
for k in new_model_instance.peft_config.keys():
new_model_instance.peft_config[
k
].base_model_name_or_path = self.pretrained_model_name

log.debug('Saving Hugging Face checkpoint to disk')
# This context manager casts the TE extra state from io.BytesIO format to tensor format,
# which is needed for the Hugging Face checkpoint to be saved properly.
Expand Down Expand Up @@ -624,10 +642,7 @@ def tensor_hook(
# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
local_save_path,
self.mlflow_logging_config['metadata'].get(
'pretrained_model_name',
None,
),
self.pretrained_model_name,
)
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
Expand Down

0 comments on commit 6f4aa8c

Please sign in to comment.