diff --git a/lazyllm/components/utils/downloader/model_downloader.py b/lazyllm/components/utils/downloader/model_downloader.py index 28c726e8..7e7863f0 100644 --- a/lazyllm/components/utils/downloader/model_downloader.py +++ b/lazyllm/components/utils/downloader/model_downloader.py @@ -26,6 +26,13 @@ def __init__(self, model_source=lazyllm.config['model_source'], self.token = token or None self.cache_dir = cache_dir self.model_paths = model_path.split(":") if len(model_path) > 0 else [] + if self.model_source == 'huggingface': + self.hub_downloader = HuggingfaceDownloader(token=self.token) + else: + self.hub_downloader = ModelscopeDownloader(token=self.token) + if self.model_source != 'modelscope': + lazyllm.LOG.warning("Only support Huggingface and Modelscope currently. " + f"Unsupported model source: {self.model_source}. Forcing use of Modelscope.") @classmethod def get_model_type(cls, model) -> str: @@ -116,22 +123,10 @@ def download(self, model='', call_back=None): return model_save_dir def validate_token(self): - if self.model_source == 'huggingface': - return HuggingfaceDownloader(token=self.token).verify_hub_token() - elif self.model_source == 'modelscope': - return ModelscopeDownloader(token=self.token).verify_hub_token() - else: - lazyllm.LOG.warning("Only support Huggingface and Modelscope currently.") - return False + return self.hub_downloader.verify_hub_token() def validate_model_id(self, model_id): - if self.model_source == 'huggingface': - return HuggingfaceDownloader(token=self.token).verify_model_id(model_id) - elif self.model_source == 'modelscope': - return ModelscopeDownloader(token=self.token).verify_model_id(model_id) - else: - lazyllm.LOG.warning("Only support Huggingface and Modelscope currently.") - return False + return self.hub_downloader.verify_model_id(model_id) def _model_exists_at_path(self, model_name): if len(self.model_paths) == 0: @@ -166,10 +161,7 @@ def _do_download(self, model='', call_back=None): full_model_dir = os.path.join(self.cache_dir, self.model_source, model_dir) try: - if self.model_source == 'huggingface': - return HuggingfaceDownloader(call_back, self.token).download(model, full_model_dir) - elif self.model_source == 'modelscope': - return ModelscopeDownloader(call_back, self.token).download(model, full_model_dir) + self.hub_downloader.download(model, full_model_dir, call_back) # Use `BaseException` to capture `KeyboardInterrupt` and normal `Exceptioin`. except BaseException as e: lazyllm.LOG.warning(f"Download encountered an error: {e}") @@ -184,9 +176,8 @@ def _do_download(self, model='', call_back=None): class HubDownloader(ABC): - def __init__(self, call_back=None, token=None): + def __init__(self, token=None): self._token = token if self._verify_hub_token(token) else None - self._call_back = call_back self._api = self._build_hub_api(self._token) @abstractmethod @@ -209,13 +200,13 @@ def _do_download(self, model_id, model_dir): def _get_repo_files(self, model_id): pass - def _polling_progress(self, model_dir, total, polling_event): + def _polling_progress(self, model_dir, total, polling_event, call_back): while not polling_event.is_set(): n = self._get_current_files_size(model_dir) n = min(n, total) - if callable(self._call_back): + if callable(call_back): try: - self._call_back(n, total) + call_back(n, total) except Exception as e: print(f"Error in callback: {e}") time.sleep(1) @@ -235,15 +226,16 @@ def _get_files_total_size(self, hub_model_info): size += item['Size'] return size - def download(self, model_id, model_dir): + def download(self, model_id, model_dir, call_back=None): total = self._get_files_total_size(self._get_repo_files(model_id)) - if self._call_back: + if call_back: polling_event = threading.Event() - polling_thread = threading.Thread(target=self._polling_progress, args=(model_dir, total, polling_event)) + polling_thread = threading.Thread(target=self._polling_progress, + args=(model_dir, total, polling_event, call_back)) polling_thread.daemon = True polling_thread.start() downloaded_path = self._do_download(model_id, model_dir) - if self._call_back and polling_thread: + if call_back and polling_thread: polling_event.set() polling_thread.join() return downloaded_path