Skip to content

Commit

Permalink
TrainableModule support multi-args: input, files (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 authored Nov 8, 2024
1 parent 5e97b85 commit 8654413
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 107 deletions.
3 changes: 2 additions & 1 deletion lazyllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
LazyLLMValidateBase, register as component_register, Prompter,
AlpacaPrompter, ChatPrompter, FastapiApp, JsonFormatter, FileFormatter)

from .module import (ModuleBase, UrlModule, TrainableModule, ActionModule,
from .module import (ModuleBase, ModuleBase as Module, UrlModule, TrainableModule, ActionModule,
ServerModule, TrialModule, register as module_register,
OnlineChatModule, OnlineEmbeddingModule, AutoModel)
from .client import redis_client
Expand Down Expand Up @@ -47,6 +47,7 @@

# module
'ModuleBase',
'Module',
'UrlModule',
'TrainableModule',
'ActionModule',
Expand Down
3 changes: 2 additions & 1 deletion lazyllm/components/formatter/formatterbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ def _parse_py_data_by_formatter(self, msg: str):

LAZYLLM_QUERY_PREFIX = '<lazyllm-query>'

def encode_query_with_filepaths(query: str = None, files: List[str] = None) -> str:
def encode_query_with_filepaths(query: str = None, files: Union[str, List[str]] = None) -> str:
query = query if query else ''
query_with_docs = {'query': query, 'files': files}
if files:
if isinstance(files, str): files = [files]
assert isinstance(files, list), "files must be a list."
assert all(isinstance(item, str) for item in files), "All items in files must be strings"
return LAZYLLM_QUERY_PREFIX + json.dumps(query_with_docs)
Expand Down
17 changes: 17 additions & 0 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _parse_py_data_by_formatter(self, data):

@NodeConstructor.register('JoinFormatter')
def make_join_formatter(type='sum', names=None, symbol=None):
    """Construct the formatter used to join several inputs.

    Args:
        type: Join strategy. The value ``'file'`` is special-cased and is
            delegated to ``make_formatter('file', rule='merge')``.
        names: Forwarded to ``JoinFormatter`` as ``names``.
        symbol: Forwarded to ``JoinFormatter`` as ``symbol``.

    Returns:
        The object returned by ``make_formatter`` for ``type == 'file'``,
        otherwise a ``JoinFormatter``.
    """
    if type != 'file':
        return JoinFormatter(type, names=names, symbol=symbol)
    # `names` and `symbol` are not used for the file-merge formatter.
    return make_formatter('file', rule='merge')

@NodeConstructor.register('Formatter')
Expand Down Expand Up @@ -387,3 +388,19 @@ def make_shared_llm(llm: str, prompt: Optional[str] = None):
@NodeConstructor.register('VQA')
def make_vqa(base_model: str):
    """Create the module backing a 'VQA' node.

    Args:
        base_model: Passed to ``lazyllm.TrainableModule`` as its only argument.

    Returns:
        The result of calling ``deploy_method(lazyllm.deploy.LMDeploy)`` on
        the newly created trainable module.
    """
    module = lazyllm.TrainableModule(base_model)
    return module.deploy_method(lazyllm.deploy.LMDeploy)

@NodeConstructor.register('STT')
def make_stt(base_model: str):
    """Create the flow backing an 'STT' (speech-to-text) node.

    The returned flow is ``lazyllm.ifs(cond, tpath=<model>, fpath=<identity>)``
    where ``cond`` reports whether the input looks like an encoded query
    (contains ``'<lazyllm-query>'``) that mentions an audio file extension.

    Args:
        base_model: Passed to ``lazyllm.TrainableModule`` for the ``tpath``
            branch.

    Returns:
        The object built by ``lazyllm.ifs``.
    """
    # TODO: support multi-files with pictures
    audio_exts = ('.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a', '.wma')

    def cond(x):
        # Only strings can be encoded queries; anything else is reported as
        # "not audio" instead of running a membership test on it.
        if not isinstance(x, str) or '<lazyllm-query>' not in x:
            return False
        # Case-insensitive comparison so mixed-case extensions such as
        # '.Mp3' are recognised, not only all-lower or all-upper spellings.
        lowered = x.lower()
        return any(ext in lowered for ext in audio_exts)

    return lazyllm.ifs(cond, tpath=lazyllm.TrainableModule(base_model), fpath=lazyllm.Identity())

@NodeConstructor.register('Constant')
def make_constant(value: Any):
    """Create the callable backing a 'Constant' node.

    Args:
        value: The object the node should always produce.

    Returns:
        A callable that accepts any positional and keyword arguments,
        ignores them, and returns ``value``.
    """
    def constant(*args, **kw):
        return value

    return constant
4 changes: 4 additions & 0 deletions lazyllm/engine/lightengine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .engine import Engine, Node
import lazyllm
from lazyllm import once_wrapper
from typing import List, Dict, Optional, Set, Union
import copy
Expand Down Expand Up @@ -91,4 +92,7 @@ def update(self, gid_or_nodes: Union[str, Dict, List[Dict]], nodes: List[Dict],
for node in gid_or_nodes: self.update_node(node)

def run(self, id: str, *args, **kw):
    """Run the node identified by ``id`` with the given inputs.

    When the keyword ``_lazyllm_files`` is supplied with a truthy value, it
    is removed from ``kw`` and combined with the (optional) single
    positional query into one encoded argument produced by
    ``lazyllm.formatter.file(formatter='encode')``.

    Args:
        id: Identifier handed to ``self.build_node``.
        *args: Positional inputs for the node; at most one when files are given.
        **kw: Keyword inputs for the node; must be empty (apart from
            ``_lazyllm_files``) when files are given.

    Returns:
        Whatever the built node's ``func`` returns.
    """
    files = kw.pop('_lazyllm_files', None)
    if files:
        assert len(args) <= 1 and len(kw) == 0, 'At most one query is enabled when file exists'
        query = args[0] if args else ''
        encoder = lazyllm.formatter.file(formatter='encode')
        args = [encoder(dict(query=query, files=files))]
    node = self.build_node(id)
    return node.func(*args, **kw)
2 changes: 1 addition & 1 deletion lazyllm/engine/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class NodeArgs(object):
secret_key=NodeArgs(str, None))
)

all_nodes['SD'] = all_nodes['TTS'] = all_nodes['STT'] = dict(
all_nodes['SD'] = all_nodes['TTS'] = dict(
module=lazyllm.TrainableModule,
init_arguments=dict(base_model=NodeArgs(str))
)
Expand Down
4 changes: 3 additions & 1 deletion lazyllm/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def __post_init__(self):
self._nodes[Graph.start_node_name] = Graph.Node(None, Graph.start_node_name)
self._nodes[Graph.end_node_name] = Graph.Node(lazyllm.Identity(), Graph.end_node_name)
self._in_degree = {node: 0 for node in self._nodes.values()}
self._out_degree = {node: 0 for node in self._nodes.values()}
self._sorted_nodes = None

def set_node_arg_name(self, arg_names):
Expand All @@ -529,6 +530,7 @@ def add_edge(self, from_node, to_node, formatter=None):
assert from_node.name not in to_node.inputs, f'Duplicate edges from {from_node.name} to {to_node.name}'
to_node.inputs[from_node.name] = formatter
self._in_degree[to_node] += 1
self._out_degree[from_node] += 1

def topological_sort(self):
in_degree = self._in_degree.copy()
Expand All @@ -546,7 +548,7 @@ def topological_sort(self):
if len(sorted_nodes) != len(self._nodes):
raise ValueError("Graph has a cycle")

return [n for n in sorted_nodes if (self._in_degree[n] > 0 or n.name == Graph.start_node_name)]
return [n for n in sorted_nodes if (self._in_degree[n] > 0 or self._out_degree[n] > 0)]

def compute_node(self, sid, node, intermediate_results, futures):
globals._init_sid(sid)
Expand Down
4 changes: 2 additions & 2 deletions lazyllm/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,8 @@ def launch(self, job) -> None:


class RemoteLauncher(LazyLLMLaunchersBase):
def __new__(cls, *args, sync=False, **kwargs):
return getattr(lazyllm.launchers, lazyllm.config['launcher'])(*args, sync=sync, **kwargs)
def __new__(cls, *args, sync=False, ngpus=1, **kwargs):
return getattr(lazyllm.launchers, lazyllm.config['launcher'])(*args, sync=sync, ngpus=ngpus, **kwargs)


def cleanup():
Expand Down
69 changes: 19 additions & 50 deletions lazyllm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import functools
from datetime import datetime
from lazyllm import ThreadPoolExecutor, FileSystemQueue
from typing import Dict, List, Any, Union, Optional
from typing import Dict, List, Any, Union, Optional, Tuple

import lazyllm
from lazyllm import FlatList, Option, launchers, LOG, package, kwargs, encode_request, globals
Expand Down Expand Up @@ -90,6 +90,7 @@ def _setattr(v, *, _return_value=self, **kw):
def __call__(self, *args, **kw):
try:
kw.update(globals['global_parameters'].get(self._module_id, dict()))
if (files := globals['lazyllm_files'].get(self._module_id)) is not None: kw['lazyllm_files'] = files
if (history := globals['chat_history'].get(self._module_id)) is not None: kw['llm_chat_history'] = history
r = self.forward(**args[0], **kw) if args and isinstance(args[0], kwargs) else self.forward(*args, **kw)
if self._return_trace:
Expand Down Expand Up @@ -271,24 +272,22 @@ def _set_url(self, url):
# Cannot modify or add any attribute of self
# prompt keys (excluding history) are in __input (ATTENTION: dict, not kwargs)
# deploy parameters keys are in **kw
def forward(self, __input=package(), *, llm_chat_history=None, lazyllm_files=None, tools=None, stream_output=False, **kw): # noqa C901
def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = package(), # noqa C901
*, llm_chat_history=None, lazyllm_files=None, tools=None, stream_output=False, **kw):
assert self._url is not None, f'Please start {self.__class__} first'
stream_output = stream_output or self._stream
url = self._url

files = []
# p2. specific module_files
if self._module_id in globals['lazyllm_files'] and globals['lazyllm_files'][self._module_id]:
files = globals['lazyllm_files'].pop(self._module_id)
# p1. forward_files
if self.template_message and isinstance(__input, str) and __input.startswith(LAZYLLM_QUERY_PREFIX):
deinput = decode_query_with_filepaths(__input)
__input = deinput['query']
if deinput['files']:
files = deinput['files']
# p0. bind_files
if lazyllm_files:
files = _lazyllm_get_file_list(lazyllm_files)
if self.template_message:
if isinstance(__input, package):
assert not lazyllm_files, 'Duplicate `files` argument provided by args and kwargs'
__input, lazyllm_files = __input
if isinstance(__input, str) and __input.startswith(LAZYLLM_QUERY_PREFIX):
assert not lazyllm_files, 'Argument `files` is already provided by query'
deinput = decode_query_with_filepaths(__input)
__input, files = deinput['query'], deinput['files']
else:
files = _lazyllm_get_file_list(lazyllm_files) if lazyllm_files else []

query = __input
__input = self._prompt.generate_prompt(query, llm_chat_history, tools)
Expand Down Expand Up @@ -393,6 +392,11 @@ def _modify_parameters(self, paras, kw):
def set_default_parameters(self, **kw):
    """Apply ``kw`` to this module's ``template_message``.

    The work is delegated to ``self._modify_parameters``, which receives the
    current ``template_message`` and the keyword arguments as a dict.
    """
    template = self.template_message
    self._modify_parameters(template, kw)

def __call__(self, *args, **kw):
    """Invoke the module, accepting more than one positional input.

    Zero or one positional argument is forwarded to the parent ``__call__``
    unchanged. Two or more are wrapped into a single ``package`` so the
    parent receives exactly one positional argument.
    """
    if len(args) <= 1:
        return super(__class__, self).__call__(*args, **kw)
    return super(__class__, self).__call__(package(args), **kw)

def __repr__(self):
    """Return the ``lazyllm.make_repr`` description of this module.

    The description is tagged ``'Module'`` / ``'Url'`` and carries the
    module's name, url, stream flag and return_trace flag.
    """
    fields = dict(
        name=self._module_name,
        url=self._url,
        stream=self._stream,
        return_trace=self._return_trace,
    )
    return lazyllm.make_repr('Module', 'Url', **fields)
Expand Down Expand Up @@ -481,11 +485,6 @@ def __init__(self, m, pre=None, post=None, stream=False, return_trace=False,

_url_id = property(lambda self: self._impl._module_id)

def __call__(self, *args, **kw):
    """Call the parent ``__call__``, bundling multiple positional inputs.

    With two or more positional arguments they are passed on as one
    ``package``; otherwise the arguments are forwarded as given.
    """
    forwarded = (package(args),) if len(args) > 1 else args
    return super(__class__, self).__call__(*forwarded, **kw)

def wait(self):
    """Call ``wait()`` on the launcher held by the wrapped implementation."""
    launcher = self._impl._launcher
    launcher.wait()

Expand Down Expand Up @@ -861,36 +860,6 @@ def share(self, prompt=None, format=None):
new._impl._add_father(new)
return new

class Module(object):
    """Factory facade: ``Module(...)`` returns a concrete module object.

    Dispatch performed by ``__new__`` (first matching rule wins):
      * modules (module instances, a list of them, or ``modules=``) -> ActionModule
      * action (a ``FlowBase`` instance, or ``action=``) -> ActionModule
      * url (a single ``str``, or ``url=``) -> UrlModule
      * anything else -> TrainableModule

    The ``action`` / ``url`` / ``trainable`` classmethods construct the
    corresponding class explicitly, forwarding all arguments.
    """

    def __new__(cls, *args, **kw):
        # `__new__` never returns a `Module` instance, so testing against
        # `Module` alone could not match objects produced by this factory.
        # ModuleBase is accepted as well (assumes the concrete module classes
        # derive from it -- TODO confirm).
        module_types = (Module, ModuleBase)
        if len(args) >= 1 and isinstance(args[0], module_types):
            return ActionModule(*args)
        elif len(args) == 1 and isinstance(args[0], list) and isinstance(args[0][0], module_types):
            return ActionModule(*args[0])
        elif len(args) == 0 and 'modules' in kw:
            return ActionModule(kw['modules'])
        elif len(args) == 1 and isinstance(args[0], FlowBase):
            return ActionModule(args[0])
        elif len(args) == 0 and 'action' in kw:
            # Read the key that was just tested; looking up 'modules' here
            # raised KeyError for every call that reached this branch.
            return ActionModule(kw['action'])
        elif len(args) == 1 and isinstance(args[0], str):
            return UrlModule(url=args[0])
        elif len(args) == 0 and 'url' in kw:
            return UrlModule(url=kw['url'])
        # Fallback for every other call shape.
        # NOTE(review): args/kw are not forwarded here, so e.g. a
        # (base_model, target_path) pair is dropped -- confirm whether
        # TrainableModule(*args, **kw) was intended.
        return TrainableModule()

    @classmethod
    def action(cls, *args, **kw):
        """Construct an ActionModule from the given arguments."""
        return ActionModule(*args, **kw)

    @classmethod
    def url(cls, *args, **kw):
        """Construct a UrlModule from the given arguments."""
        return UrlModule(*args, **kw)

    @classmethod
    def trainable(cls, *args, **kw):
        """Construct a TrainableModule from the given arguments."""
        return TrainableModule(*args, **kw)


class ModuleRegistryBase(ModuleBase, metaclass=lazyllm.LazyLLMRegisterMetaClass):
__reg_overwrite__ = 'forward'
Expand Down
2 changes: 1 addition & 1 deletion tests/advanced_tests/standard_test/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_vlm_and_lmdeploy(self):

globals['lazyllm_files'][chat._module_id] = [pig_path]
assert '猪' in m(query)
globals['lazyllm_files'][chat._module_id] = [pig_path]
globals['lazyllm_files'][chat._module_id] = None
assert '鸡' in m(f'<lazyllm-query>{{"query":"{query}","files":["{ji_path}"]}}')

_, client = self.warp_into_web(m)
Expand Down
76 changes: 26 additions & 50 deletions tests/advanced_tests/standard_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import lazyllm
import os
from lazyllm.engine import LightEngine

class TestEngine(object):
Expand All @@ -11,27 +11,6 @@ def _test_vqa(self):
engine = LightEngine()
engine.start(node, edge, resource)

def _test_multimedia(self):
painter_p = 'Now you are a master of drawing prompts, capable of converting any Chinese content entered by the user into English drawing prompts. In this task, you need to convert any input content into English drawing prompts, and you can enrich and expand the prompt content.' # noqa E501
musician_p = 'Now you are a master of music composition prompts, capable of converting any Chinese content entered by the user into English music composition prompts. In this task, you need to convert any input content into English music composition prompts, and you can enrich and expand the prompt content.' # noqa E501

resources = [dict(id='0', kind='LocalLLM', name='base', args=dict(base_model='internlm2-chat-7b')),
dict(id='1', kind='web', name='web', args=dict(port=None, title='多模态聊天机器人', audio=True))]
nodes = [dict(id='2', kind='Intention', name='intent', args=dict(base_model='0', nodes={
'Chat': dict(id='3', kind='SharedLLM', name='chat', args=dict(llm='0')),
'Speech Recognition': dict(id='4', kind='STT', name='stt', args=dict(base_model='SenseVoiceSmall')),
'Image QA': dict(id='5', kind='VQA', name='vqa', args=dict(base_model='Mini-InternVL-Chat-2B-V1-5')),
'Drawing': [dict(id='6', kind='SharedLLM', name='drow_prompt', args=dict(llm='0', prompt=painter_p)),
dict(id='7', kind='SD', name='sd', args=dict(base_model='stable-diffusion-3-medium'))],
'Generate Music': [dict(id='8', kind='SharedLLM', name='translate', args=dict(llm='0', prompt=musician_p)),
dict(id='9', kind='TTS', name='music', args=dict(base_model='musicgen-small'))],
'Text to Speech': dict(id='10', kind='TTS', name='speech', args=dict(base_model='ChatTTS')),
}))]

edges = [dict(iid="__start__", oid="2"), dict(iid="2", oid="__end__")]
engine = LightEngine()
engine.start(nodes, edges, resources)

def test_http(self):
nodes = [
dict(
Expand All @@ -54,44 +33,39 @@ def test_http(self):
ret = engine.run(gid)
assert '商汤科技' in ret['content']

def test_multimedia2(self):
def test_multimedia(self):
painter_prompt = 'Now you are a master of drawing prompts, capable of converting any Chinese content entered by the user into English drawing prompts. In this task, you need to convert any input content into English drawing prompts, and you can enrich and expand the prompt content.' # noqa E501
musician_prompt = 'Now you are a master of music composition prompts, capable of converting any Chinese content entered by the user into English music composition prompts. In this task, you need to convert any input content into English music composition prompts, and you can enrich and expand the prompt content.' # noqa E501
translator_prompt = 'Now you are a master of translation prompts, capable of converting any Chinese content entered by the user into English translation prompts. In this task, you need to convert any input content into English translation prompts, and you can enrich and expand the prompt content.' # noqa E501

resources = [dict(id='0', kind='LocalLLM', name='base', args=dict(base_model='internlm2-chat-7b')),
dict(id='1', kind='web', name='web', args=dict(port=None, title='多模态聊天机器人', audio=True))]
resources = [dict(id='llm', kind='LocalLLM', name='base', args=dict(base_model='internlm2-chat-7b')),
dict(id='vqa', kind='VQA', name='vqa', args=dict(base_model='Mini-InternVL-Chat-2B-V1-5')),
dict(id='web', kind='web', name='web', args=dict(port=None, title='多模态聊天机器人', audio=True))]

nodes1 = [
dict(id='2', kind='SharedLLM', name='draw_prompt', args=dict(llm='0', prompt=painter_prompt)),
dict(id='2', kind='SharedLLM', name='draw_prompt', args=dict(llm='llm', prompt=painter_prompt)),
dict(id='3', kind='SD', name='sd', args=dict(base_model='stable-diffusion-3-medium')),
dict(id='4', kind='Code', name='vqa_query', args='def static_str(x):\n return "描述图片"\n'),
dict(id='5', kind='Formatter', name='merge_sd_vqa1', args=dict(ftype='file', rule='merge')),
dict(id='6', kind='VQA', name='vqa', args=dict(base_model='Mini-InternVL-Chat-2B-V1-5')),
dict(id='7', kind='Formatter', name='merge_sd_vqa2', args=dict(ftype='file', rule='merge')),
dict(id='4', kind='Constant', name='vqa_query', args='描述图片'),
dict(id='5', kind='SharedLLM', name='vqa1', args=dict(llm='vqa')),
dict(id='6', kind='JoinFormatter', name='merge_sd_vqa2', args=dict(type='file')),
]
edges1 = [
dict(iid='__start__', oid='2'), dict(iid='7', oid='__end__'),
dict(iid="2", oid="3"), dict(iid="3", oid="4"), dict(iid="3", oid="5"),
dict(iid="4", oid="5"), dict(iid="5", oid="6"), dict(iid="3", oid="7"), dict(iid="6", oid="7"),
dict(iid='__start__', oid='2'), dict(iid='6', oid='__end__'), dict(iid="2", oid="3"),
dict(iid="4", oid="5"), dict(iid="3", oid="5"), dict(iid="3", oid="6"), dict(iid="5", oid="6"),
]

speech_recog = dict(id='8', kind='STT', name='stt', args=dict(base_model='SenseVoiceSmall'))
ident = dict(id='9', kind='Code', name='ident', args='def ident(x):\n return x\n')
nodes = [dict(id='10', kind='Formatter', name='encode_input', args=dict(ftype='file', rule='encode')),
dict(id='11', kind='Ifs', name='voice_or_txt', args=dict(
cond='def cond(x): return "<lazyllm-query>" in x', true=[speech_recog], false=[ident])),
dict(id='12', kind='Intention', name='intent', args=dict(base_model='0', nodes={
'Drawing': dict(id='14', kind='SubGraph', name='draw_vqa', args=dict(nodes=nodes1, edges=edges1)),
'Translate': dict(id='15', kind='SharedLLM', name='translate_prompt',
args=dict(llm='0', prompt=translator_prompt)),
'Generate Music': [dict(id='16', kind='SharedLLM', name='translate',
args=dict(llm='0', prompt=musician_prompt)),
dict(id='17', kind='TTS', name='music',
nodes = [dict(id='7', kind='STT', name='stt', args=dict(base_model='SenseVoiceSmall')),
dict(id='8', kind='Intention', name='intent', args=dict(base_model='llm', nodes={
'Drawing': dict(id='9', kind='SubGraph', name='draw_vqa', args=dict(nodes=nodes1, edges=edges1)),
'Translate': dict(id='10', kind='SharedLLM', name='translate_prompt',
args=dict(llm='llm', prompt=translator_prompt)),
'Generate Music': [dict(id='11', kind='SharedLLM', name='translate',
args=dict(llm='llm', prompt=musician_prompt)),
dict(id='12', kind='TTS', name='music',
args=dict(base_model='musicgen-small'))],
'Chat': dict(id='18', kind='SharedLLM', name='chat', args=dict(llm='0'))}))]
edges = [dict(iid="__start__", oid="10"), dict(iid="10", oid="11"),
dict(iid="11", oid="12"), dict(iid="12", oid="__end__")]
'Image Question Answering': dict(id='13', kind='SharedLLM', name='vqa2', args=dict(llm='vqa')),
'Chat': dict(id='14', kind='SharedLLM', name='chat', args=dict(llm='llm'))}))]
edges = [dict(iid="__start__", oid="7"), dict(iid="7", oid="8"), dict(iid="8", oid="__end__")]

engine = LightEngine()
gid = engine.start(nodes, edges, resources)
Expand All @@ -102,10 +76,12 @@ def test_multimedia2(self):
r = engine.run(gid, '翻译:我喜欢敲代码。')
assert 'code' in r

audio_path = os.path.join(lazyllm.config['data_path'], 'ci_data/draw_pig.mp3')
r = engine.run(gid, {"query": "", "files": [f"{audio_path}"]})
r = engine.run(gid, "", _lazyllm_files=os.path.join(lazyllm.config['data_path'], 'ci_data/draw_pig.mp3'))
assert '.png' in r

r = engine.run(gid, "这张图片描述的是什么?", _lazyllm_files=os.path.join(lazyllm.config['data_path'], 'ci_data/ji.jpg'))
assert '鸡' in r or 'chicken' in r

r = engine.run(gid, '你好,很高兴认识你')
assert '你好' in r

Expand Down

0 comments on commit 8654413

Please sign in to comment.