diff --git a/yolov5/utils/downloads.py b/yolov5/utils/downloads.py index e07acb8..fcd9bdc 100644 --- a/yolov5/utils/downloads.py +++ b/yolov5/utils/downloads.py @@ -93,7 +93,7 @@ def github_assets(repository, version='latest'): return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets # try to download from hf hub - result = attempt_download_from_hub(file, hf_token=hf_token) + result = attempt_download_from_hub(repo, hf_token=hf_token, model_file=file) if result is not None: file = result @@ -145,7 +145,7 @@ def get_model_filename_from_hfhub(repo_id, hf_token=None): return None -def attempt_download_from_hub(repo_id, hf_token=None, revision=None): +def attempt_download_from_hub(repo_id, hf_token=None, revision=None, model_file=None): from huggingface_hub import hf_hub_download, list_repo_files from huggingface_hub.utils._errors import RepositoryNotFoundError from huggingface_hub.utils._validators import HFValidationError @@ -165,7 +165,8 @@ def attempt_download_from_hub(repo_id, hf_token=None, revision=None): ) # download model file - model_file = [f for f in repo_files if f.endswith('.pt')][0] + if model_file is None: + model_file = [f for f in repo_files if f.endswith('.pt')][0] file = hf_hub_download( repo_id=repo_id, filename=model_file,