diff --git a/lazyllm/components/auto/autodeploy.py b/lazyllm/components/auto/autodeploy.py
index 54181e17..29bccbf2 100644
--- a/lazyllm/components/auto/autodeploy.py
+++ b/lazyllm/components/auto/autodeploy.py
@@ -15,7 +15,7 @@ def __new__(cls, base_model, source=lazyllm.config['model_source'], trust_remote
                 launcher=launchers.remote(ngpus=1), stream=False, type=None, **kw):
         base_model = ModelDownloader(source).download(base_model)
         model_name = get_model_name(base_model)
-        if type == 'embed' or cls.get_model_type(model_name) == 'embed':
+        if type == 'embed' or ModelDownloader.get_model_type(model_name) == 'embed':
             return EmbeddingDeploy(trust_remote_code, launcher)
         map_name = model_map(model_name)
         candidates = get_configer().query_deploy(lazyllm.config['gpu_type'], launcher.ngpus,
@@ -31,10 +31,3 @@ def __new__(cls, base_model, source=lazyllm.config['model_source'], trust_remote
             return deploy_cls(trust_remote_code=trust_remote_code, launcher=launcher, stream=stream, **kw)
         raise RuntimeError(f'No valid framework found, candidates are {[c.framework.lower() for c in candidates]}')

-    @classmethod
-    def get_model_type(cls, model_name):
-        from lazyllm.components.utils.downloader.model_mapping import model_name_mapping
-        if model_name in model_name_mapping:
-            return model_name_mapping[model_name].get('type', 'llm')
-        else:
-            return 'llm'
diff --git a/lazyllm/components/auto/autofinetune.py b/lazyllm/components/auto/autofinetune.py
index 40b26b21..6413208a 100644
--- a/lazyllm/components/auto/autofinetune.py
+++ b/lazyllm/components/auto/autofinetune.py
@@ -11,7 +11,7 @@ def __new__(cls, base_model, target_path, source=lazyllm.config['model_source'],
                 batch_size=32, lora_r=8, launcher=launchers.remote(ngpus=1), **kw):
         base_model = ModelDownloader(source).download(base_model)
         model_name = get_model_name(base_model)
-        if cls.get_model_type(model_name) == 'embed':
+        if ModelDownloader.get_model_type(model_name) == 'embed':
             raise RuntimeError('Fine-tuning of the embed model is not currently supported.')
         map_name = model_map(model_name)
         base_name = model_name.split('-')[0].split('_')[0].lower()
@@ -31,11 +31,3 @@ def __new__(cls, base_model, target_path, source=lazyllm.config['model_source'],
             return finetune_cls(base_model, target_path, merge_path, cp_files='tokeniz*', batch_size=batch_size,
                                 lora_r=lora_r, launcher=launcher, **kw)
         raise RuntimeError(f'No valid framework found, candidates are {[c.framework.lower() for c in candidates]}')
-
-    @classmethod
-    def get_model_type(cls, model_name):
-        from lazyllm.components.utils.downloader.model_mapping import model_name_mapping
-        if model_name in model_name_mapping:
-            return model_name_mapping[model_name].get('type', 'llm')
-        else:
-            return 'llm'
diff --git a/lazyllm/components/utils/downloader/model_downloader.py b/lazyllm/components/utils/downloader/model_downloader.py
index 24e04d6e..aae9c7e8 100644
--- a/lazyllm/components/utils/downloader/model_downloader.py
+++ b/lazyllm/components/utils/downloader/model_downloader.py
@@ -19,7 +19,22 @@ def __init__(self, model_source=lazyllm.config['model_source'],
         self.token = token
         self.cache_dir = cache_dir
         self.model_pathes = model_path.split(":") if len(model_path) > 0 else []
-
+
+    @classmethod
+    def get_model_type(cls, model) -> str:
+        assert isinstance(model, str) and len(model) > 0, "model name should be a non-empty string"
+        for name, info in model_name_mapping.items():
+            if 'type' not in info: continue
+
+            model_name_set = {name.casefold()}
+            for source in info:
+                if source == 'type': continue
+                model_name_set.add(info[source].split('/')[-1].casefold())
+
+            if model.split(os.sep)[-1].casefold() in model_name_set:
+                return info['type']
+        return 'llm'
+
     def download(self, model=''):
         assert isinstance(model, str), "model name should be a string."
         if len(model) == 0 or model[0] in (os.sep, '.', '~'): return model  # Dummy or local model.
diff --git a/lazyllm/module/module.py b/lazyllm/module/module.py
index 554dcae8..a5d861b1 100644
--- a/lazyllm/module/module.py
+++ b/lazyllm/module/module.py
@@ -10,6 +10,7 @@
 import lazyllm
 from lazyllm import FlatList, LazyLlmResponse, LazyLlmRequest, Option, launchers, LOG
 from ..components.prompter import PrompterBase, ChatPrompter, EmptyPrompter
+from ..components.utils import ModelDownloader
 from ..flow import FlowBase, Pipeline, Parallel
 import uuid
 from ..client import get_redis, redis_client
@@ -374,18 +375,19 @@ class TrainableModule(UrlModule):
     __enable_request__ = False

     def __init__(self, base_model: Option = '', target_path='', *, stream=False, return_trace=False):
+        self.base_model = base_model
         super().__init__(url=None, stream=stream, meta=TrainableModule, return_trace=return_trace)
         # Fake base_model and target_path for dummy
         self.target_path = target_path
         self._train = None  # lazyllm.train.auto
         self._finetune = lazyllm.finetune.auto
         self._deploy = lazyllm.deploy.auto
-
-        self.base_model = base_model
         self._deploy_flag = lazyllm.once_flag()

     # modify default value to ''
     def prompt(self, prompt=''):
+        if prompt == '' and ModelDownloader.get_model_type(self.base_model) != 'llm':
+            prompt = None
         return super(__class__, self).prompt(prompt)

     def _get_args(self, arg_cls, disable=[]):
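
A minimal usage sketch of the consolidated helper. It assumes `model_name_mapping` contains an embedding entry whose name (or whose per-source repo id) ends in `bge-large-zh-v1.5`; the model names below are illustrative, not a claim about the shipped mapping:

```python
from lazyllm.components.utils import ModelDownloader

# The lookup compares only the last path component, case-folded, against each
# mapped name and the tail of each per-source repo id, so a hub id and a local
# checkpoint path resolve to the same mapping entry.
ModelDownloader.get_model_type('bge-large-zh-v1.5')        # 'embed', if so mapped
ModelDownloader.get_model_type('/ckpt/BGE-Large-ZH-v1.5')  # same entry, via the path tail
ModelDownloader.get_model_type('totally-unknown-model')    # 'llm' (default fallback)
```

With this in place, `TrainableModule.prompt('')` resolves to `prompt = None` whenever the base model's type is not `llm`, so embedding models are no longer wrapped in the default chat prompter.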