Add Hugging Face model download script #708

Merged · 22 commits · Nov 6, 2023
4 changes: 4 additions & 0 deletions llmfoundry/utils/__init__.py
@@ -11,6 +11,8 @@
from llmfoundry.utils.config_utils import (calculate_batch_size_info,
log_config, pop_config,
update_batch_size_info)
from llmfoundry.utils.model_download_utils import (download_from_cache_server,
download_from_hf_hub)
except ImportError as e:
raise ImportError(
'Please make sure to pip install . to get requirements for llm-foundry.'
@@ -26,6 +28,8 @@
'build_tokenizer',
'calculate_batch_size_info',
'convert_and_save_ft_weights',
'download_from_cache_server',
'download_from_hf_hub',
'get_hf_tokenizer_from_composer_state_dict',
'update_batch_size_info',
'log_config',
213 changes: 213 additions & 0 deletions llmfoundry/utils/model_download_utils.py
@@ -0,0 +1,213 @@
"""Utility functions for downloading models.

Copyright 2022 MosaicML LLM Foundry authors
SPDX-License-Identifier: Apache-2.0
"""
import copy
import logging
import os
import time
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin

import huggingface_hub as hf_hub
import requests
from bs4 import BeautifulSoup
from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME,
                                WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME,
                                WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME)

DEFAULT_IGNORE_PATTERNS = [
'*.ckpt',
'*.h5',
'*.msgpack',
]
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'

log = logging.getLogger(__name__)


def download_from_hf_hub(
repo_id: str,
save_dir: Optional[str] = None,
prefer_safetensors: bool = True,
token: Optional[str] = None,
):
"""Downloads model files from a Hugging Face Hub model repo.

Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the
Safetensors weights will be downloaded unless `prefer_safetensors` is set to False.

Args:
repo_id (str): The Hugging Face Hub repo ID.
save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads
from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory.
prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
available. Defaults to True.
token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
`HUGGING_FACE_HUB_TOKEN` environment variable.
"""
repo_files = set(hf_hub.list_repo_files(repo_id, token=token))

# Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer.
ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS)

safetensors_available = (
SAFE_WEIGHTS_NAME in repo_files or SAFE_WEIGHTS_INDEX_NAME in repo_files
)
pytorch_available = (
PYTORCH_WEIGHTS_NAME in repo_files or PYTORCH_WEIGHTS_INDEX_NAME in repo_files
)

if safetensors_available and pytorch_available:
if prefer_safetensors:
log.info(
'Safetensors available and preferred. Excluding pytorch weights.')
ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN)
else:
log.info(
'Pytorch available and preferred. Excluding safetensors weights.')
ignore_patterns.append(SAFE_WEIGHTS_PATTERN)
elif safetensors_available:
log.info('Only safetensors available. Ignoring weights preference.')
elif pytorch_available:
log.info('Only pytorch available. Ignoring weights preference.')
else:
raise ValueError(
f'No supported model weights found in repo {repo_id}.'
+ ' Please make sure the repo contains either safetensors or pytorch weights.'
)

download_start = time.time()
hf_hub.snapshot_download(
repo_id, cache_dir=save_dir, ignore_patterns=ignore_patterns, token=token
)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds'
)


def _extract_links_from_html(html: str):
"""Extracts links from HTML content.

Args:
html (str): The HTML content

Returns:
list[str]: A list of links to download.
"""
soup = BeautifulSoup(html, 'html.parser')
links = [a['href'] for a in soup.find_all('a')]
return links


def _recursive_download(
session: requests.Session,
base_url: str,
path: str,
save_dir: str,
ignore_cert: bool = False,
):
"""Downloads all files/subdirectories from a directory on a remote server.

Args:
session: A requests.Session through which to make requests to the remote server.
base_url (str): The base URL where the files are located.
path (str): The path from the base URL to the files to download. The full URL for the download is equal to
'<base_url>/<path>'.
save_dir (str): The directory to save downloaded files to.
ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
Defaults to False.
WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.

Raises:
PermissionError: If the remote server returns a 401 Unauthorized status code.
RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
"""

url = urljoin(base_url, path)
response = session.get(url, verify=(not ignore_cert))

if response.status_code == HTTPStatus.UNAUTHORIZED:
raise PermissionError(
f'Not authorized to download file from {url}. Received status code {response.status_code}. '
)
elif response.status_code != HTTPStatus.OK:
raise RuntimeError(
f'Could not download file from {url}. Received unexpected status code {response.status_code}'
)

# Assume that the URL points to a file if it does not end with a slash.
if not path.endswith('/'):
save_path = os.path.join(save_dir, path)
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)

with open(save_path, 'wb') as f:
f.write(response.content)

log.info(f'Downloaded file {save_path}')
return

# If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links
# to download.
child_links = _extract_links_from_html(response.content.decode())
for child_link in child_links:
_recursive_download(
session, base_url, urljoin(path, child_link), save_dir, ignore_cert=ignore_cert
)


def download_from_cache_server(
model_name: str,
cache_base_url: str,
save_dir: str,
token: Optional[str] = None,
ignore_cert: bool = False,
):
"""Downloads Hugging Face models from a mirror file server.

The file server is expected to store the files in the same structure as the Hugging Face cache
structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache.

Args:
model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face
Hub.
cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob
files from `<cache_base_url>/<formatted_model_name>/blobs/`, where `formatted_model_name` is equal to
`models/<model_name>` with all slashes replaced with `--`.
save_dir: The directory to save the downloaded files to.
token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN`
environment variable.
ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to
False.
WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
"""
formatted_model_name = f'models/{model_name}'.replace('/', '--')
with requests.Session() as session:
session.headers.update({'Authorization': f'Bearer {token}'})

download_start = time.time()

# Only download the blobs to avoid fetching model files twice through the
# symlinks in the Hugging Face cache structure.
_recursive_download(
session,
cache_base_url,
# Trailing slash to indicate directory
f'{formatted_model_name}/blobs/',
save_dir,
ignore_cert=ignore_cert,
)
download_duration = time.time() - download_start
log.info(
f'Downloaded model {model_name} from cache server in {download_duration} seconds'
)
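For context (not part of the diff): a minimal sketch of how the two public helpers above might be called once this PR is installed. The repo ID, mirror URL, and save directory below are illustrative placeholders, not values taken from the PR.

import os

from llmfoundry.utils.model_download_utils import (download_from_cache_server,
                                                    download_from_hf_hub)

# Pull Safetensors weights (preferred) from the Hub into the default HF cache.
download_from_hf_hub('mosaicml/mpt-7b',  # placeholder repo ID
                     prefer_safetensors=True,
                     token=os.getenv('HUGGING_FACE_HUB_TOKEN'))

# Pull the same model from a mirror file server; blobs are fetched from
# <cache_base_url>/models--mosaicml--mpt-7b/blobs/.
download_from_cache_server('mosaicml/mpt-7b',                     # placeholder model name
                           'https://hf-mirror.example.com/',      # hypothetical mirror URL
                           '/tmp/hf-models',                      # placeholder save directory
                           token=os.getenv('HUGGING_FACE_HUB_TOKEN'))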
62 changes: 62 additions & 0 deletions scripts/misc/download_hf_model.py
@@ -0,0 +1,62 @@
"""Script to download model weights from Hugging Face Hub or a cache server.

Copyright 2022 MosaicML LLM Foundry authors
SPDX-License-Identifier: Apache-2.0
"""
import argparse
import logging
import os
import sys

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from llmfoundry.utils.model_download_utils import (download_from_cache_server,
                                                    download_from_hf_hub)

HF_TOKEN_ENV_VAR = 'HUGGING_FACE_HUB_TOKEN'

log = logging.getLogger(__name__)

if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--model', type=str, required=True)
argparser.add_argument(
'--download-from', type=str, choices=['hf', 'cache'], default='hf'
)
argparser.add_argument('--token', type=str, default=os.getenv(HF_TOKEN_ENV_VAR))
argparser.add_argument('--save-dir', type=str, default=HUGGINGFACE_HUB_CACHE)
argparser.add_argument('--cache-url', type=str, default=None)
argparser.add_argument('--ignore-cert', action='store_true', default=False)
argparser.add_argument(
'--fallback',
action='store_true',
default=False,
help='Whether to fallback to downloading from Hugging Face if download from cache fails',
)

args = argparser.parse_args(sys.argv[1:])
if args.download_from == 'hf':
download_from_hf_hub(args.model, save_dir=args.save_dir, token=args.token)
else:
try:
download_from_cache_server(
args.model,
args.cache_url,
args.save_dir,
token=args.token,
ignore_cert=args.ignore_cert,
)
except PermissionError:
log.error(f'Not authorized to download {args.model}.')
except Exception as e:
if args.fallback:
log.warning(
f'Failed to download {args.model} from cache server. Falling back to Hugging Face Hub. Error: {e}'
)
download_from_hf_hub(
args.model, save_dir=args.save_dir, token=args.token
)
else:
raise e
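For reference (not part of the diff), example invocations of the script using the flags it defines above; the model name, cache URL, and save directory are placeholders.

# Download from the Hugging Face Hub (default source):
python scripts/misc/download_hf_model.py --model mosaicml/mpt-7b --save-dir /tmp/hf-models

# Download from a mirror cache server, falling back to the Hub on failure:
python scripts/misc/download_hf_model.py --model mosaicml/mpt-7b --download-from cache \
    --cache-url https://hf-mirror.example.com/ --save-dir /tmp/hf-models --fallback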
1 change: 1 addition & 0 deletions setup.py
@@ -66,6 +66,7 @@
'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python',
'boto3>=1.21.45,<2',
'huggingface-hub>=0.17.0,<1.0',
'beautifulsoup4>=4.12.2,<5', # required for model download utils
]

extra_deps = {}