Skip to content

Commit

Permalink
Merge branch 'main' into depwarn
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Feb 6, 2024
2 parents 835d20e + 2e0a845 commit 2e444b9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
36 changes: 32 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
maybe_create_remote_uploader_downloader_from_uri,
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
Expand All @@ -32,17 +33,41 @@
_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _maybe_get_license_filename(local_dir: str) -> Optional[str]:
def _maybe_get_license_filename(
local_dir: str,
pretrained_model_name: Optional[str] = None) -> Optional[str]:
"""Returns the name of the license file if it exists in the local_dir.
Note: This is intended to be consistent with the code in MLflow.
https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152
Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
in which case this function will use that to fetch the correct license information.
If the license file does not exist, returns None.
"""
try:
return next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))
license_filename = next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))

# If a pretrained model name is provided, replace the license file with the correct info from HF Hub.
if pretrained_model_name is not None:
log.info(
f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub'
)
os.remove(os.path.join(local_dir, license_filename))
model_card = _fetch_model_card(pretrained_model_name)

local_dir_path = Path(local_dir).absolute()
_write_license_information(pretrained_model_name, model_card,
local_dir_path)
license_filename = next(file for file in os.listdir(local_dir)
if _LICENSE_FILE_PATTERN.search(file))

return license_filename

except StopIteration:
return None

Expand Down Expand Up @@ -330,8 +355,11 @@ def _save_checkpoint(self, state: State, logger: Logger):

mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
local_save_path)
local_save_path,
self.mlflow_logging_config['metadata'].get(
'pretrained_model_name', None))
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]
extra_deps['gpu-flash2'] = [
'flash-attn==2.4.2',
'flash-attn==2.5.0',
'mosaicml-turbo==0.0.8',
]

Expand Down

0 comments on commit 2e444b9

Please sign in to comment.