diff --git a/asteroid/utils/hub_utils.py b/asteroid/utils/hub_utils.py index f6295375f..bcd0b08d5 100644 --- a/asteroid/utils/hub_utils.py +++ b/asteroid/utils/hub_utils.py @@ -65,15 +65,15 @@ def cached_download(filename_or_url): else: model_id = filename_or_url revision = None - url = huggingface_hub.hf_hub_url( - model_id, filename=huggingface_hub.PYTORCH_WEIGHTS_NAME, revision=revision - ) - return huggingface_hub.cached_download( - url, + return huggingface_hub.hf_hub_download( + repo_id=model_id, + filename=huggingface_hub.PYTORCH_WEIGHTS_NAME, cache_dir=get_cache_dir(), + revision=revision, library_name="asteroid", library_version=asteroid_version, ) + cached_filename = url_to_filename(url) cached_dir = os.path.join(get_cache_dir(), cached_filename) cached_path = os.path.join(cached_dir, "model.pth")