diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 38cc562c9d..7abe4dcf75 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -11,6 +11,8 @@ from llmfoundry.utils.config_utils import (calculate_batch_size_info, log_config, pop_config, update_batch_size_info) + from llmfoundry.utils.model_download_utils import ( + download_from_cache_server, download_from_hf_hub) except ImportError as e: raise ImportError( 'Please make sure to pip install . to get requirements for llm-foundry.' @@ -26,6 +28,8 @@ 'build_tokenizer', 'calculate_batch_size_info', 'convert_and_save_ft_weights', + 'download_from_cache_server', + 'download_from_hf_hub', 'get_hf_tokenizer_from_composer_state_dict', 'update_batch_size_info', 'log_config', diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py new file mode 100644 index 0000000000..d268cb78b7 --- /dev/null +++ b/llmfoundry/utils/model_download_utils.py @@ -0,0 +1,228 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for downloading models.""" +import copy +import logging +import os +import time +from http import HTTPStatus +from typing import Optional +from urllib.parse import urljoin + +import huggingface_hub as hf_hub +import requests +import tenacity +from bs4 import BeautifulSoup +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME +from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME + +DEFAULT_IGNORE_PATTERNS = [ + '*.ckpt', + '*.h5', + '*.msgpack', +] +PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*' +SAFE_WEIGHTS_PATTERN = 'model*.safetensors*' + +log = logging.getLogger(__name__) + + +@tenacity.retry(retry=tenacity.retry_if_not_exception_type( + (ValueError, hf_hub.utils.RepositoryNotFoundError)), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=1, max=10)) +def download_from_hf_hub( + repo_id: str, + save_dir: Optional[str] = None, + prefer_safetensors: bool = True, + token: Optional[str] = None, +): + """Downloads model files from a Hugging Face Hub model repo. + + Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the + Safetensors weights will be downloaded unless `prefer_safetensors` is set to False. + + Args: + repo_id (str): The Hugging Face Hub repo ID. + save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads + from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory. + prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are + available. Defaults to True. + token (str, optional): The HuggingFace API token. If not provided, the token will be read from the + `HUGGING_FACE_HUB_TOKEN` environment variable. + + Raises: + RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized. + ValueError: If the model repo doesn't contain any supported model weights. + """ + repo_files = set(hf_hub.list_repo_files(repo_id)) + + # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer. + ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS) + + safetensors_available = (SAFE_WEIGHTS_NAME in repo_files or + SAFE_WEIGHTS_INDEX_NAME in repo_files) + pytorch_available = (PYTORCH_WEIGHTS_NAME in repo_files or + PYTORCH_WEIGHTS_INDEX_NAME in repo_files) + + if safetensors_available and pytorch_available: + if prefer_safetensors: + log.info( + 'Safetensors available and preferred. Excluding pytorch weights.' + ) + ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN) + else: + log.info( + 'Pytorch available and preferred. Excluding safetensors weights.' + ) + ignore_patterns.append(SAFE_WEIGHTS_PATTERN) + elif safetensors_available: + log.info('Only safetensors available. Ignoring weights preference.') + elif pytorch_available: + log.info('Only pytorch available. Ignoring weights preference.') + else: + raise ValueError( + f'No supported model weights found in repo {repo_id}.' + + ' Please make sure the repo contains either safetensors or pytorch weights.' + ) + + download_start = time.time() + hf_hub.snapshot_download(repo_id, + cache_dir=save_dir, + ignore_patterns=ignore_patterns, + token=token) + download_duration = time.time() - download_start + log.info( + f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds' + ) + + +def _extract_links_from_html(html: str): + """Extracts links from HTML content. + + Args: + html (str): The HTML content + + Returns: + list[str]: A list of links to download. + """ + soup = BeautifulSoup(html, 'html.parser') + links = [a['href'] for a in soup.find_all('a')] + return links + + +def _recursive_download( + session: requests.Session, + base_url: str, + path: str, + save_dir: str, + ignore_cert: bool = False, +): + """Downloads all files/subdirectories from a directory on a remote server. + + Args: + session: A requests.Session through which to make requests to the remote server. + url (str): The base URL where the files are located. + path (str): The path from the base URL to the files to download. The full URL for the download is equal to + '/'. + save_dir (str): The directory to save downloaded files to. + ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server. + Defaults to False. + WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. + + Raises: + PermissionError: If the remote server returns a 401 Unauthorized status code. + ValueError: If the remote server returns a 404 Not Found status code. + RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized. + """ + url = urljoin(base_url, path) + response = session.get(url, verify=(not ignore_cert)) + + if response.status_code == HTTPStatus.UNAUTHORIZED: + raise PermissionError( + f'Not authorized to download file from {url}. Received status code {response.status_code}. ' + ) + elif response.status_code == HTTPStatus.NOT_FOUND: + raise ValueError( + f'Could not find file at {url}. Received status code {response.status_code}' + ) + elif response.status_code != HTTPStatus.OK: + raise RuntimeError( + f'Could not download file from {url}. Received unexpected status code {response.status_code}' + ) + + # Assume that the URL points to a file if it does not end with a slash. + if not path.endswith('/'): + save_path = os.path.join(save_dir, path) + parent_dir = os.path.dirname(save_path) + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) + + with open(save_path, 'wb') as f: + f.write(response.content) + + log.info(f'Downloaded file {save_path}') + return + + # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links + # to download. + child_links = _extract_links_from_html(response.content.decode()) + for child_link in child_links: + _recursive_download(session, + base_url, + urljoin(path, child_link), + save_dir, + ignore_cert=ignore_cert) + + +@tenacity.retry(retry=tenacity.retry_if_not_exception_type( + (PermissionError, ValueError)), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=1, max=10)) +def download_from_cache_server( + model_name: str, + cache_base_url: str, + save_dir: str, + token: Optional[str] = None, + ignore_cert: bool = False, +): + """Downloads Hugging Face models from a mirror file server. + + The file server is expected to store the files in the same structure as the Hugging Face cache + structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache. + + Args: + model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face + Hub. + cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob + files from `//blobs/`, where `formatted_model_name` is equal to + `models/` with all slashes replaced with `--`. + save_dir: The directory to save the downloaded files to. + token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` + environment variable. + ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to + False. + WARNING: Setting this to true is *not* secure, as no certificate verification will be performed. + """ + formatted_model_name = f'models/{model_name}'.replace('/', '--') + with requests.Session() as session: + session.headers.update({'Authorization': f'Bearer {token}'}) + + download_start = time.time() + + # Only downloads the blobs in order to avoid downloading model files twice due to the + # symlnks in the Hugging Face cache structure: + _recursive_download( + session, + cache_base_url, + # Trailing slash to indicate directory + f'{formatted_model_name}/blobs/', + save_dir, + ignore_cert=ignore_cert, + ) + download_duration = time.time() - download_start + log.info( + f'Downloaded model {model_name} from cache server in {download_duration} seconds' + ) diff --git a/scripts/misc/download_hf_model.py b/scripts/misc/download_hf_model.py new file mode 100644 index 0000000000..6465a552c2 --- /dev/null +++ b/scripts/misc/download_hf_model.py @@ -0,0 +1,67 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to download model weights from Hugging Face Hub or a cache server.""" +import argparse +import logging +import os +import sys + +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + +from llmfoundry.utils.model_download_utils import (download_from_cache_server, + download_from_hf_hub) + +HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN' + +log = logging.getLogger(__name__) + +if __name__ == '__main__': + argparser = argparse.ArgumentParser() + argparser.add_argument('--model', type=str, required=True) + argparser.add_argument('--download-from', + type=str, + choices=['hf', 'cache'], + default='hf') + argparser.add_argument('--token', + type=str, + default=os.getenv(HF_TOKEN_ENV_VAR)) + argparser.add_argument('--save-dir', + type=str, + default=HUGGINGFACE_HUB_CACHE) + argparser.add_argument('--cache-url', type=str, default=None) + argparser.add_argument('--ignore-cert', action='store_true', default=False) + argparser.add_argument( + '--fallback', + action='store_true', + default=False, + help= + 'Whether to fallback to downloading from Hugging Face if download from cache fails', + ) + + args = argparser.parse_args(sys.argv[1:]) + if args.download_from == 'hf': + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token) + else: + try: + download_from_cache_server( + args.model, + args.cache_url, + args.save_dir, + token=args.token, + ignore_cert=args.ignore_cert, + ) + except PermissionError: + log.error(f'Not authorized to download {args.model}.') + except Exception as e: + if args.fallback: + log.warn( + f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}' + ) + download_from_hf_hub(args.model, + save_dir=args.save_dir, + token=args.token) + else: + raise e diff --git a/setup.py b/setup.py index 63aac9d752..f528838d35 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,8 @@ 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python', 'boto3>=1.21.45,<2', 'huggingface-hub>=0.17.0,<1.0', + 'beautifulsoup4>=4.12.2,<5', # required for model download utils + 'tenacity>=8.2.3,<9', ] extra_deps = {} @@ -101,7 +103,8 @@ extra_deps['peft'] = [ 'loralib==0.1.1', # lora core 'bitsandbytes==0.39.1', # 8bit - 'scipy>=1.10.0,<=1.11.0', # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes + # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes + 'scipy>=1.10.0,<=1.11.0', # TODO: pin peft when it stabilizes. # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'peft==0.4.0', diff --git a/tests/test_model_download_utils.py b/tests/test_model_download_utils.py new file mode 100644 index 0000000000..27b9805cda --- /dev/null +++ b/tests/test_model_download_utils.py @@ -0,0 +1,248 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import unittest.mock as mock +from http import HTTPStatus +from typing import Any, Dict, List +from unittest.mock import MagicMock +from urllib.parse import urljoin + +import pytest +import requests +import tenacity +from huggingface_hub.utils import RepositoryNotFoundError +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME +from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME + +from llmfoundry.utils.model_download_utils import (DEFAULT_IGNORE_PATTERNS, + PYTORCH_WEIGHTS_PATTERN, + SAFE_WEIGHTS_PATTERN, + download_from_cache_server, + download_from_hf_hub) + +# ======================== download_from_hf_hub tests ======================== + + +@pytest.mark.parametrize( + ['prefer_safetensors', 'repo_files', 'expected_ignore_patterns'], + [ + [ # Should use default ignore if only safetensors available + True, + [SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only safetensors available + False, + [SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ # Should use default ignore if only sharded safetensors available + True, + [SAFE_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only sharded safetensors available + False, + [SAFE_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only pytorch available + True, + [PYTORCH_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only pytorch available + False, + [PYTORCH_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only sharded pytorch available + True, + [PYTORCH_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ + # Should use default ignore if only sharded pytorch available + False, + [PYTORCH_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS, + ], + [ # Ignore pytorch if safetensors are preferred + True, + [PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS + [PYTORCH_WEIGHTS_PATTERN], + ], + [ # Ignore safetensors if pytorch is preferred + False, + [PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS + [SAFE_WEIGHTS_PATTERN], + ], + [ # Ignore pytorch if safetensors are preferred + True, + [PYTORCH_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME], + DEFAULT_IGNORE_PATTERNS + [PYTORCH_WEIGHTS_PATTERN], + ], + [ # Ignore safetensors if pytorch is preferred + False, + [PYTORCH_WEIGHTS_NAME, SAFE_WEIGHTS_NAME], + DEFAULT_IGNORE_PATTERNS + [SAFE_WEIGHTS_PATTERN], + ], + ]) +@mock.patch('huggingface_hub.snapshot_download') +@mock.patch('huggingface_hub.list_repo_files') +def test_download_from_hf_hub_weights_pref(mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, + prefer_safetensors: bool, + repo_files: List[str], + expected_ignore_patterns: List[str]): + test_repo_id = 'test_repo_id' + mock_list_repo_files.return_value = repo_files + + download_from_hf_hub(test_repo_id, prefer_safetensors=prefer_safetensors) + mock_snapshot_download.assert_called_once_with( + test_repo_id, + cache_dir=None, + ignore_patterns=expected_ignore_patterns, + token=None, + ) + + +@mock.patch('huggingface_hub.snapshot_download') +@mock.patch('huggingface_hub.list_repo_files') +def test_download_from_hf_hub_no_weights( + mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, +): + test_repo_id = 'test_repo_id' + mock_list_repo_files.return_value = [] + + with pytest.raises(ValueError): + download_from_hf_hub(test_repo_id) + + mock_snapshot_download.assert_not_called() + + +@pytest.mark.parametrize(['exception', 'expected_attempts'], [ + [requests.exceptions.RequestException(), 3], + [RepositoryNotFoundError(''), 1], + [ValueError(), 1], +]) +@mock.patch('tenacity.nap.time.sleep') +@mock.patch('huggingface_hub.snapshot_download') +@mock.patch('huggingface_hub.list_repo_files') +def test_download_from_hf_hub_retry( + mock_list_repo_files: MagicMock, + mock_snapshot_download: MagicMock, + mock_sleep: MagicMock, # so the retry wait doesn't actually wait + exception: BaseException, + expected_attempts: int, +): + mock_list_repo_files.return_value = [SAFE_WEIGHTS_INDEX_NAME] + mock_snapshot_download.side_effect = exception + + with pytest.raises((tenacity.RetryError, exception.__class__)): + download_from_hf_hub('test_repo_id') + + assert mock_snapshot_download.call_count == expected_attempts + + +# ======================== download_from_cache_server tests ======================== + +ROOT_HTML = b""" + + + + + + +""" + +SUBFOLDER_HTML = b""" + + + + + + +""" + + +@mock.patch.object(requests.Session, 'get') +@mock.patch('os.makedirs') +@mock.patch('builtins.open') +def test_download_from_cache_server(mock_open: MagicMock, + mock_makedirs: MagicMock, + mock_get: MagicMock): + cache_url = 'https://cache.com/' + model_name = 'model' + formatted_model_name = 'models--model' + save_dir = 'save_dir/' + + mock_open.return_value = MagicMock() + + def _server_response(url: str, **kwargs: Dict[str, Any]): + if url == urljoin(cache_url, f'{formatted_model_name}/blobs/'): + return MagicMock(status_code=HTTPStatus.OK, content=ROOT_HTML) + if url == urljoin(cache_url, f'{formatted_model_name}/blobs/file1'): + return MagicMock(status_code=HTTPStatus.OK) + elif url == urljoin(cache_url, f'{formatted_model_name}/blobs/folder/'): + return MagicMock(status_code=HTTPStatus.OK, content=SUBFOLDER_HTML) + elif url == urljoin(cache_url, + f'{formatted_model_name}/blobs/folder/file2'): + return MagicMock(status_code=HTTPStatus.OK) + else: + return MagicMock(status_code=HTTPStatus.NOT_FOUND) + + mock_get.side_effect = _server_response + download_from_cache_server(model_name, cache_url, 'save_dir/') + + mock_open.assert_has_calls([ + mock.call(os.path.join(save_dir, formatted_model_name, 'blobs/file1'), + 'wb'), + mock.call( + os.path.join(save_dir, formatted_model_name, 'blobs/folder/file2'), + 'wb'), + ], + any_order=True) + + +@mock.patch.object(requests.Session, 'get') +def test_download_from_cache_server_unauthorized(mock_get: MagicMock): + cache_url = 'https://cache.com/' + model_name = 'model' + save_dir = 'save_dir/' + + mock_get.return_value = MagicMock(status_code=HTTPStatus.UNAUTHORIZED) + with pytest.raises(PermissionError): + download_from_cache_server(model_name, cache_url, save_dir) + + +@pytest.mark.parametrize(['exception', 'expected_attempts'], [ + [requests.exceptions.RequestException(), 3], + [PermissionError(), 1], + [ValueError(), 1], +]) +@mock.patch('tenacity.nap.time.sleep') +@mock.patch('llmfoundry.utils.model_download_utils._recursive_download') +def test_download_from_cache_server_retry( + mock_recursive_download: MagicMock, + mock_sleep: MagicMock, # so the retry wait doesn't actually wait + exception: BaseException, + expected_attempts: int, +): + mock_recursive_download.side_effect = exception + + with pytest.raises((tenacity.RetryError, exception.__class__)): + download_from_cache_server('model', 'cache_url', 'save_dir')