Merge branch 'model_type' into 'main'
add model type getter

See merge request tps-llm/lazyllm!141
wzh1994 committed Jun 12, 2024
2 parents 31fd78f + 5f6f79d commit 812959a
Showing 4 changed files with 22 additions and 20 deletions.
9 changes: 1 addition & 8 deletions lazyllm/components/auto/autodeploy.py
@@ -15,7 +15,7 @@ def __new__(cls, base_model, source=lazyllm.config['model_source'], trust_remote
launcher=launchers.remote(ngpus=1), stream=False, type=None, **kw):
base_model = ModelDownloader(source).download(base_model)
model_name = get_model_name(base_model)
if type == 'embed' or cls.get_model_type(model_name) == 'embed':
if type == 'embed' or ModelDownloader.get_model_type(model_name) == 'embed':
return EmbeddingDeploy(trust_remote_code, launcher)
map_name = model_map(model_name)
candidates = get_configer().query_deploy(lazyllm.config['gpu_type'], launcher.ngpus,
@@ -31,10 +31,3 @@
return deploy_cls(trust_remote_code=trust_remote_code, launcher=launcher, stream=stream, **kw)
raise RuntimeError(f'No valid framework found, candidates are {[c.framework.lower() for c in candidates]}')

@classmethod
def get_model_type(cls, model_name):
from lazyllm.components.utils.downloader.model_mapping import model_name_mapping
if model_name in model_name_mapping:
return model_name_mapping[model_name].get('type', 'llm')
else:
return 'llm'
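
The deleted classmethod above only matched exact keys in model_name_mapping; the shared ModelDownloader.get_model_type that replaces it (see the model_downloader.py hunk below) also matches per-source repo basenames and local path basenames, case-insensitively. A minimal sketch of the new call site; the model name is hypothetical and assumed to be mapped with type 'embed':

    from lazyllm.components.utils import ModelDownloader

    model_name = 'bge-large-zh'  # hypothetical name, assumed mapped with type 'embed'
    explicit_type = None         # stands in for the `type` keyword argument above

    # Mirrors the updated condition in AutoDeploy.__new__.
    if explicit_type == 'embed' or ModelDownloader.get_model_type(model_name) == 'embed':
        print('dispatching to EmbeddingDeploy')
    else:
        print('falling through to the LLM deploy candidates')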
10 changes: 1 addition & 9 deletions lazyllm/components/auto/autofinetune.py
@@ -11,7 +11,7 @@ def __new__(cls, base_model, target_path, source=lazyllm.config['model_source'],
batch_size=32, lora_r=8, launcher=launchers.remote(ngpus=1), **kw):
base_model = ModelDownloader(source).download(base_model)
model_name = get_model_name(base_model)
if cls.get_model_type(model_name) == 'embed':
if ModelDownloader.get_model_type(model_name) == 'embed':
raise RuntimeError('Fine-tuning of the embed model is not currently supported.')
map_name = model_map(model_name)
base_name = model_name.split('-')[0].split('_')[0].lower()
@@ -31,11 +31,3 @@ def __new__(cls, base_model, target_path, source=lazyllm.config['model_source'],
return finetune_cls(base_model, target_path, merge_path, cp_files='tokeniz*',
batch_size=batch_size, lora_r=lora_r, launcher=launcher, **kw)
raise RuntimeError(f'No valid framework found, candidates are {[c.framework.lower() for c in candidates]}')

@classmethod
def get_model_type(cls, model_name):
from lazyllm.components.utils.downloader.model_mapping import model_name_mapping
if model_name in model_name_mapping:
return model_name_mapping[model_name].get('type', 'llm')
else:
return 'llm'
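
The same duplicated helper is dropped here in favor of the shared lookup; its only behavioral use in AutoFinetune is the guard above, sketched standalone below. The model name is hypothetical; names absent from the mapping fall back to 'llm' and pass the check:

    from lazyllm.components.utils import ModelDownloader

    def ensure_finetunable(model_name: str) -> None:
        # Mirrors the guard in AutoFinetune.__new__ above.
        if ModelDownloader.get_model_type(model_name) == 'embed':
            raise RuntimeError('Fine-tuning of the embed model is not currently supported.')

    ensure_finetunable('internlm2-chat-7b')  # hypothetical LLM name: no mapping entry, defaults to 'llm'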
17 changes: 16 additions & 1 deletion lazyllm/components/utils/downloader/model_downloader.py
@@ -19,7 +19,22 @@ def __init__(self, model_source=lazyllm.config['model_source'],
self.token = token
self.cache_dir = cache_dir
self.model_pathes = model_path.split(":") if len(model_path) > 0 else []


@classmethod
def get_model_type(cls, model) -> str:
assert isinstance(model, str) and len(model) > 0, "model name should be a non-empty string"
for name, info in model_name_mapping.items():
if 'type' not in info: continue

model_name_set = {name.casefold()}
for source in info:
if source == 'type': continue
model_name_set.add(info[source].split('/')[-1].casefold())

if model.split(os.sep)[-1].casefold() in model_name_set:
return info['type']
return 'llm'

def download(self, model=''):
assert isinstance(model, str), "model name should be a string."
if len(model) == 0 or model[0] in (os.sep, '.', '~'): return model # Dummy or local model.
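To illustrate the lookup just added, here is a self-contained sketch with a hypothetical model_name_mapping excerpt (the real table lives in lazyllm/components/utils/downloader/model_mapping.py). The method matches the mapping key, each per-source repo basename, and the last component of a local path, all case-folded, and defaults to 'llm':

    import os

    # Hypothetical excerpt of model_name_mapping; the real table is larger.
    model_name_mapping = {
        'bge-large-zh-v1.5': {
            'type': 'embed',
            'huggingface': 'BAAI/bge-large-zh-v1.5',
            'modelscope': 'AI-ModelScope/bge-large-zh-v1.5',
        },
    }

    def get_model_type(model: str) -> str:
        assert isinstance(model, str) and len(model) > 0, "model name should be a non-empty string"
        for name, info in model_name_mapping.items():
            if 'type' not in info:
                continue
            # Collect the mapping key plus each source's repo basename, case-folded.
            names = {name.casefold()}
            names.update(info[src].split('/')[-1].casefold() for src in info if src != 'type')
            # Compare only the last path component, so local paths match too.
            if model.split(os.sep)[-1].casefold() in names:
                return info['type']
        return 'llm'

    print(get_model_type('/models/BGE-Large-ZH-v1.5'))  # embed
    print(get_model_type('chatglm3-6b'))                # llm (default for unknown names)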
6 changes: 4 additions & 2 deletions lazyllm/module/module.py
@@ -10,6 +10,7 @@
import lazyllm
from lazyllm import FlatList, LazyLlmResponse, LazyLlmRequest, Option, launchers, LOG
from ..components.prompter import PrompterBase, ChatPrompter, EmptyPrompter
from ..components.utils import ModelDownloader
from ..flow import FlowBase, Pipeline, Parallel
import uuid
from ..client import get_redis, redis_client
@@ -374,18 +375,19 @@ class TrainableModule(UrlModule):
__enable_request__ = False

def __init__(self, base_model: Option = '', target_path='', *, stream=False, return_trace=False):
self.base_model = base_model
super().__init__(url=None, stream=stream, meta=TrainableModule, return_trace=return_trace)
# Fake base_model and target_path for dummy
self.target_path = target_path
self._train = None # lazyllm.train.auto
self._finetune = lazyllm.finetune.auto
self._deploy = lazyllm.deploy.auto

self.base_model = base_model
self._deploy_flag = lazyllm.once_flag()

# modify default value to ''
def prompt(self, prompt=''):
if prompt == '' and ModelDownloader.get_model_type(self.base_model) != 'llm':
prompt = None
return super(__class__, self).prompt(prompt)

def _get_args(self, arg_cls, disable=[]):
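The net effect of the prompt() change, with the logic inlined as a standalone sketch (model names hypothetical): under the new default prompt='', non-LLM models such as embeddings get prompt=None so no prompter is attached, while LLMs keep the empty-string default.

    from lazyllm.components.utils import ModelDownloader

    def resolve_prompt(base_model, prompt=''):
        # Mirrors TrainableModule.prompt above: non-LLM models skip prompting by default.
        if prompt == '' and ModelDownloader.get_model_type(base_model) != 'llm':
            prompt = None
        return prompt

    print(resolve_prompt('bge-large-zh-v1.5'))  # None, assuming the name is mapped as 'embed'
    print(resolve_prompt('chatglm3-6b'))        # '' stays the default for LLMs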
