Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to huggingface hub download with revision version #255

Merged
merged 1 commit into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions yolov5/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def load_model(
model_path, device=None, autoshape=True, verbose=False, hf_token: str = None
model_path, device=None, autoshape=True, verbose=False, hf_token: str = None, revision: str = None
):
"""
Creates a specified YOLOv5 model
Expand All @@ -21,6 +21,7 @@ def load_model(
autoshape (bool): make model ready for inference
verbose (bool): if False, yolov5 logs will be silent
hf_token (str): huggingface read token for private models
revision (str): huggingface model revision

Returns:
pytorch model
Expand All @@ -36,7 +37,7 @@ def load_model(

try:
model = DetectMultiBackend(
model_path, device=device, fuse=autoshape, hf_token=hf_token
model_path, device=device, fuse=autoshape, hf_token=hf_token, revision=revision
) # detection model
if autoshape:
if model.pt and isinstance(model.model, ClassificationModel):
Expand Down
4 changes: 2 additions & 2 deletions yolov5/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def forward(self, x):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True, hf_token=None):
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True, hf_token=None, revision=None):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
Expand All @@ -335,7 +335,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
w = str(weights[0] if isinstance(weights, list) else weights)

# try to dowload from hf hub
result = attempt_download_from_hub(w, hf_token=hf_token)
result = attempt_download_from_hub(w, hf_token=hf_token, revision=revision)
if result is not None:
w = result

Expand Down
4 changes: 3 additions & 1 deletion yolov5/utils/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_model_filename_from_hfhub(repo_id, hf_token=None):
return None


def attempt_download_from_hub(repo_id, hf_token=None):
def attempt_download_from_hub(repo_id, hf_token=None, revision=None):
from huggingface_hub import hf_hub_download, list_repo_files
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
Expand All @@ -161,6 +161,7 @@ def attempt_download_from_hub(repo_id, hf_token=None):
filename=config_file,
repo_type='model',
token=hf_token,
revision=revision,
)

# download model file
Expand All @@ -170,6 +171,7 @@ def attempt_download_from_hub(repo_id, hf_token=None):
filename=model_file,
repo_type='model',
token=hf_token,
revision=revision,
)
return file
except (RepositoryNotFoundError, HFValidationError):
Expand Down
Loading