Skip to content

Commit

Permalink
added huggingface download option
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonMa2016 committed Jun 28, 2023
1 parent 571c0b6 commit d30b64d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
22 changes: 14 additions & 8 deletions liv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from os.path import expanduser
import omegaconf
import hydra
from huggingface_hub import hf_hub_download
import gdown
import torch
from torch.hub import load_state_dict_from_url
import copy
from liv.models.model_liv import LIV

Expand All @@ -29,18 +29,24 @@ def load_liv(modelid='resnet50'):
folderpath = os.path.join(home, modelid)
modelpath = os.path.join(home, modelid, "model.pt")
configpath = os.path.join(home, modelid, "config.yaml")

# Default download from PyTorch
modelurl = 'https://drive.google.com/uc?id=1l1ufzVLxpE5BK7JY6ZnVBljVzmK5c4P3'
configurl = 'https://drive.google.com/uc?id=1GWA5oSJDuHGB2WEdyZZmkro83FNmtaWl'

if not os.path.exists(modelpath):
gdown.download(modelurl, modelpath, quiet=False)
gdown.download(configurl, configpath, quiet=False)
try:
# Default download from GDown
modelurl = 'https://drive.google.com/uc?id=1l1ufzVLxpE5BK7JY6ZnVBljVzmK5c4P3'
configurl = 'https://drive.google.com/uc?id=1GWA5oSJDuHGB2WEdyZZmkro83FNmtaWl'
gdown.download(modelurl, modelpath, quiet=False)
gdown.download(configurl, configpath, quiet=False)
except:
# More reliable download from HuggingFace Hub
hf_hub_download(repo_id="jasonyma/LIV", filename="model.pt", local_dir=folderpath)
hf_hub_download(repo_id="jasonyma/LIV", filename="config.yaml", local_dir=folderpath)

modelcfg = omegaconf.OmegaConf.load(configpath)
cleancfg = cleanup_config(modelcfg)
rep = hydra.utils.instantiate(cleancfg)
rep = torch.nn.DataParallel(rep)
state_dict = torch.load(modelpath, map_location=torch.device(device))['liv']
rep.load_state_dict(state_dict)
return rep
return rep

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def read(fname):
'matplotlib',
'flatten_dict',
'gdown',
'huggingface_hub',
'tabulate',
'pandas',
'scipy',
Expand Down

0 comments on commit d30b64d

Please sign in to comment.