Skip to content

Commit

Permalink
Review3: rm callback from self, And Init hub_downloader in init of mo…
Browse files Browse the repository at this point in the history
…del_manager
  • Loading branch information
JingofXin committed Dec 19, 2024
1 parent 84fd6c5 commit d238337
Showing 1 changed file with 19 additions and 27 deletions.
46 changes: 19 additions & 27 deletions lazyllm/components/utils/downloader/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ def __init__(self, model_source=lazyllm.config['model_source'],
self.token = token or None
self.cache_dir = cache_dir
self.model_paths = model_path.split(":") if len(model_path) > 0 else []
if self.model_source == 'huggingface':
self.hub_downloader = HuggingfaceDownloader(token=self.token)
else:
self.hub_downloader = ModelscopeDownloader(token=self.token)
if self.model_source != 'modelscope':
lazyllm.LOG.warning("Only support Huggingface and Modelscope currently. "
f"Unsupported model source: {self.model_source}. Forcing use of Modelscope.")

@classmethod
def get_model_type(cls, model) -> str:
Expand Down Expand Up @@ -116,22 +123,10 @@ def download(self, model='', call_back=None):
return model_save_dir

def validate_token(self):
if self.model_source == 'huggingface':
return HuggingfaceDownloader(token=self.token).verify_hub_token()
elif self.model_source == 'modelscope':
return ModelscopeDownloader(token=self.token).verify_hub_token()
else:
lazyllm.LOG.warning("Only support Huggingface and Modelscope currently.")
return False
return self.hub_downloader.verify_hub_token()

def validate_model_id(self, model_id):
if self.model_source == 'huggingface':
return HuggingfaceDownloader(token=self.token).verify_model_id(model_id)
elif self.model_source == 'modelscope':
return ModelscopeDownloader(token=self.token).verify_model_id(model_id)
else:
lazyllm.LOG.warning("Only support Huggingface and Modelscope currently.")
return False
return self.hub_downloader.verify_model_id(model_id)

def _model_exists_at_path(self, model_name):
if len(self.model_paths) == 0:
Expand Down Expand Up @@ -166,10 +161,7 @@ def _do_download(self, model='', call_back=None):
full_model_dir = os.path.join(self.cache_dir, self.model_source, model_dir)

try:
if self.model_source == 'huggingface':
return HuggingfaceDownloader(call_back, self.token).download(model, full_model_dir)
elif self.model_source == 'modelscope':
return ModelscopeDownloader(call_back, self.token).download(model, full_model_dir)
self.hub_downloader.download(model, full_model_dir, call_back)
# Use `BaseException` to capture `KeyboardInterrupt` and normal `Exception`.
except BaseException as e:
lazyllm.LOG.warning(f"Download encountered an error: {e}")
Expand All @@ -184,9 +176,8 @@ def _do_download(self, model='', call_back=None):

class HubDownloader(ABC):

def __init__(self, call_back=None, token=None):
def __init__(self, token=None):
self._token = token if self._verify_hub_token(token) else None
self._call_back = call_back
self._api = self._build_hub_api(self._token)

@abstractmethod
Expand All @@ -209,13 +200,13 @@ def _do_download(self, model_id, model_dir):
def _get_repo_files(self, model_id):
pass

def _polling_progress(self, model_dir, total, polling_event):
def _polling_progress(self, model_dir, total, polling_event, call_back):
while not polling_event.is_set():
n = self._get_current_files_size(model_dir)
n = min(n, total)
if callable(self._call_back):
if callable(call_back):
try:
self._call_back(n, total)
call_back(n, total)
except Exception as e:
print(f"Error in callback: {e}")
time.sleep(1)
Expand All @@ -235,15 +226,16 @@ def _get_files_total_size(self, hub_model_info):
size += item['Size']
return size

def download(self, model_id, model_dir):
def download(self, model_id, model_dir, call_back=None):
total = self._get_files_total_size(self._get_repo_files(model_id))
if self._call_back:
if call_back:
polling_event = threading.Event()
polling_thread = threading.Thread(target=self._polling_progress, args=(model_dir, total, polling_event))
polling_thread = threading.Thread(target=self._polling_progress,
args=(model_dir, total, polling_event, call_back))
polling_thread.daemon = True
polling_thread.start()
downloaded_path = self._do_download(model_id, model_dir)
if self._call_back and polling_thread:
if call_back and polling_thread:
polling_event.set()
polling_thread.join()
return downloaded_path
Expand Down

0 comments on commit d238337

Please sign in to comment.