add stream & history test for engine (#349)
wzh1994 authored Nov 19, 2024
1 parent 69b8ba4 commit c934d8f
Showing 9 changed files with 123 additions and 14 deletions.
2 changes: 1 addition & 1 deletion lazyllm/components/prompter/builtinPrompt.py
@@ -176,7 +176,7 @@ def _split_instruction(self, instruction: str):
user_instruction = ""
if LazyLLMPrompterBase.ISA in instruction and LazyLLMPrompterBase.ISE in instruction:
# The instruction includes system prompts and/or user prompts
-pattern = re.compile(r"%s(.*)%s" % (LazyLLMPrompterBase.ISA, LazyLLMPrompterBase.ISE))
+pattern = re.compile(r"%s(.*)%s" % (LazyLLMPrompterBase.ISA, LazyLLMPrompterBase.ISE), re.DOTALL)
ret = re.split(pattern, instruction)
system_instruction = ret[0]
user_instruction = ret[1]
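The re.DOTALL flag matters when the instruction between the ISA/ISE markers spans multiple lines: without it, `.` stops at the first newline and the split silently fails. A minimal sketch of the behavior, using hypothetical marker values (the real ISA/ISE constants live on LazyLLMPrompterBase):

import re

# Hypothetical marker values for illustration only.
ISA, ISE = '<!isa!>', '<!ise!>'
instruction = f"You are helpful.{ISA}Line one.\nLine two.{ISE}"

pattern = re.compile(r"%s(.*)%s" % (ISA, ISE), re.DOTALL)
ret = re.split(pattern, instruction)
# With re.DOTALL: ['You are helpful.', 'Line one.\nLine two.', '']
# Without it, '.' cannot cross '\n' and no split happens at all.
print(ret)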
3 changes: 2 additions & 1 deletion lazyllm/components/prompter/prompter.py
@@ -1,6 +1,7 @@
import re
import json
import collections
+from lazyllm import LOG

templates = dict(
# Template used by Alpaca-LoRA.
@@ -75,7 +76,7 @@ def generate_prompt(self, input, history=None, tools=None, label=None, show=False
raise RuntimeError(f'Generate prompt failed, and prompt is {self._prompt}; chat-prompt'
f' is {self._chat_prompt}; input is {input}; history is {history}')
if label: input += label
-if self._show or show: print(input)
+if self._show or show: LOG.info(input)
return input

def get_response(self, response, input=None):
24 changes: 18 additions & 6 deletions lazyllm/engine/engine.py
@@ -37,7 +37,7 @@ def start(self, nodes: Dict[str, Any]) -> None:

@overload
def start(self, nodes: List[Dict] = [], edges: List[Dict] = [], resources: List[Dict] = [],
-gid: Optional[str] = None, name: Optional[str] = None) -> str:
+gid: Optional[str] = None, name: Optional[str] = None, _history_ids: Optional[List[str]] = None) -> str:
...

@overload
@@ -142,9 +142,10 @@ def _process_hook(self, node, module):


class ServerGraph(lazyllm.ModuleBase):
-def __init__(self, g: lazyllm.graph, server: Node, web: Node):
+def __init__(self, g: lazyllm.graph, server: Node, web: Node, _history_ids: Optional[List[str]] = None):
super().__init__()
self._g = lazyllm.ActionModule(g)
self._history_ids = _history_ids
if server:
if server.args.get('port'): raise NotImplementedError('Port is not supported now')
self._g = lazyllm.ServerModule(g)
@@ -205,7 +206,7 @@ def make_server_resource(kind: str, graph: ServerGraph, args: Dict[str, Any]):

@NodeConstructor.register('Graph', 'SubGraph', subitems=['nodes', 'resources'])
def make_graph(nodes: List[dict], edges: List[Union[List[str], dict]] = [],
-resources: List[dict] = [], enable_server=True):
+resources: List[dict] = [], enable_server: bool = True, _history_ids: Optional[List[str]] = None):
engine = Engine()
server_resources = dict(server=None, web=None)
for resource in resources:
@@ -238,7 +239,7 @@ def make_graph(nodes: List[dict], edges: List[Union[List[str], dict]] = [],
else:
g.add_edge(engine._nodes[edge['iid']].name, engine._nodes[edge['oid']].name, formatter)

-sg = ServerGraph(g, server_resources['server'], server_resources['web'])
+sg = ServerGraph(g, server_resources['server'], server_resources['web'], _history_ids=_history_ids)
for kind, node in server_resources.items():
if node:
node.args = dict(kind=kind, graph=sg, args=node.args)
@@ -430,17 +431,28 @@ def share(self, prompt: str):
def forward(self, *args, **kw):
return self.vqa(*args, **kw)

@property
def stream(self):
return self._vqa._stream

@stream.setter
def stream(self, v: bool):
self._vqa._stream = v


@NodeConstructor.register('VQA')
def make_vqa(base_model: str, file_resource_id: Optional[str] = None):
return VQA(base_model, file_resource_id)


@NodeConstructor.register('SharedLLM')
-def make_shared_llm(llm: str, prompt: Optional[str] = None, file_resource_id: Optional[str] = None):
+def make_shared_llm(llm: str, prompt: Optional[str] = None, stream: Optional[bool] = None,
+                    file_resource_id: Optional[str] = None):
llm = Engine().build_node(llm).func
if file_resource_id: assert isinstance(llm, VQA), 'file_resource_id is only supported in VQA'
-return VQA(llm._vqa.share(prompt=prompt), file_resource_id) if file_resource_id else llm.share(prompt=prompt)
+r = VQA(llm._vqa.share(prompt=prompt), file_resource_id) if file_resource_id else llm.share(prompt=prompt)
+if stream is not None: r.stream = stream
+return r


@NodeConstructor.register('STT')
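A minimal sketch of how the new stream argument on SharedLLM overrides the parent model's streaming flag per shared copy; node ids, prompts, and the model name are illustrative:

# Hypothetical node payloads for LightEngine; ids and model are illustrative.
resources = [dict(id='0', kind='LocalLLM', name='base',
                  args=dict(base_model='internlm2-chat-7b'))]
nodes = [
    # Shares the base weights and forces streaming on for this node only.
    dict(id='1', kind='SharedLLM', name='streamer',
         args=dict(llm='0', stream=True, prompt=dict(system='...', user='{query}'))),
    # Shares the same weights, leaving streaming at the parent's setting.
    dict(id='2', kind='SharedLLM', name='silent', args=dict(llm='0')),
]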
18 changes: 14 additions & 4 deletions lazyllm/engine/lightengine.py
@@ -1,4 +1,4 @@
-from .engine import Engine, Node
+from .engine import Engine, Node, ServerGraph
import lazyllm
from lazyllm import once_wrapper
from typing import List, Dict, Optional, Set, Union
@@ -56,7 +56,7 @@ def update_node(self, node):
self._nodes[node.id] = super(__class__, self).build_node(node)
return self._nodes[node.id]

-def start(self, nodes, edges=[], resources=[], gid=None, name=None):
+def start(self, nodes, edges=[], resources=[], gid=None, name=None, _history_ids=None):
if isinstance(nodes, str):
assert not edges and not resources and not gid and not name
self.build_node(nodes).func.start()
@@ -65,7 +65,8 @@ def start(self, nodes, edges=[], resources=[], gid=None, name=None):
else:
gid, name = gid or str(uuid.uuid4().hex), name or str(uuid.uuid4().hex)
node = Node(id=gid, kind='Graph', name=name, args=dict(
-nodes=copy.copy(nodes), edges=copy.copy(edges), resources=copy.copy(resources)))
+nodes=copy.copy(nodes), edges=copy.copy(edges),
+resources=copy.copy(resources), _history_ids=_history_ids))
with set_resources(resources):
self.build_node(node).func.start()
return gid
@@ -106,12 +107,21 @@ def update(self, gid_or_nodes: Union[str, Dict, List[Dict]], nodes: List[Dict],
for node in gid_or_nodes: self.update_node(node)

def run(self, id: str, *args, _lazyllm_files: Optional[Union[str, List[str]]] = None,
-_file_resources: Optional[Dict[str, Union[str, List[str]]]] = None, **kw):
+_file_resources: Optional[Dict[str, Union[str, List[str]]]] = None,
+_lazyllm_history: Optional[List[List[str]]] = None, **kw):
if files := _lazyllm_files:
assert len(args) <= 1 and len(kw) == 0, 'At most one query is enabled when file exists'
args = [lazyllm.formatter.file(formatter='encode')(dict(query=args[0] if args else '', files=files))]
if _file_resources:
lazyllm.globals['lazyllm_files'] = _file_resources
f = self.build_node(id).func
lazyllm.FileSystemQueue().dequeue()
if history := _lazyllm_history:
assert isinstance(f, ServerGraph) and (ids := f._history_ids), 'Only graph can support history'
if not (isinstance(history, list) and all(isinstance(h, list) for h in history)):
raise RuntimeError('History should be [[str, str], ..., [str, str]] (list of list of str)')
lazyllm.globals['chat_history'] = {Engine().build_node(i).func._module_id: history for i in ids}
result = self.build_node(id).func(*args, **kw)
lazyllm.globals['lazyllm_files'] = {}
lazyllm.globals['chat_history'] = {}
return result
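Putting the two new hooks together, a minimal usage sketch (assuming a graph started with _history_ids naming the nodes that should receive the conversation; the model name and prompt are illustrative):

from lazyllm.engine import LightEngine

# Illustrative graph: one shared LLM node that should see prior turns.
resources = [dict(id='0', kind='LocalLLM', name='base',
                  args=dict(base_model='internlm2-chat-7b'))]
nodes = [dict(id='1', kind='SharedLLM', name='chat',
              args=dict(llm='0', prompt=dict(system='Answer using the chat history.',
                                             user='{query}')))]

engine = LightEngine()
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '__end__']],
                   resources=resources, _history_ids=['1'])

# History must be a list of [user, assistant] string pairs.
history = [['What is the boiling point of water?', '100 degrees Celsius at 1 atm.']]
answer = engine.run(gid, 'And at high altitude?', _lazyllm_history=history)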
4 changes: 2 additions & 2 deletions lazyllm/engine/node.py
@@ -51,7 +51,7 @@ class NodeArgs(object):
init_arguments=dict(
base_model=NodeArgs(str),
target_path=NodeArgs(str),
-stream=NodeArgs(bool, True),
+stream=NodeArgs(bool, False),
return_trace=NodeArgs(bool, False)),
builder_argument=dict(
trainset=NodeArgs(str),
@@ -67,7 +67,7 @@ class NodeArgs(object):
base_url=NodeArgs(str),
api_key=NodeArgs(str, None),
secret_key=NodeArgs(str, None),
-stream=NodeArgs(bool, True),
+stream=NodeArgs(bool, False),
return_trace=NodeArgs(bool, False)),
builder_argument=dict(
prompt=NodeArgs(str)),
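With both defaults flipped from True to False, streaming is now opt-in: a node that wants token streaming must say so explicitly. An illustrative payload:

# Streaming is now opt-in; omitting `stream` yields a non-streaming node.
llm_node = dict(id='1', kind='LocalLLM', name='m1',
                args=dict(base_model='internlm2-chat-7b', stream=True))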
9 changes: 9 additions & 0 deletions lazyllm/module/module.py
@@ -707,6 +707,7 @@ def __init__(self, base_model: Option = '', target_path='', *, stream=False, ret
None, lazyllm.finetune.auto, lazyllm.deploy.auto)
self._impl._add_father(self)
self.prompt()
self._stream = stream

base_model = property(lambda self: self._impl._base_model)
target_path = property(lambda self: self._impl._target_path)
@@ -720,6 +721,14 @@ def series(self):
def type(self):
return ModelManager.get_model_type(self.base_model).upper()

@property
def stream(self):
return self._stream

@stream.setter
def stream(self, v: bool):
self._stream = v

def get_all_finetuned_models(self):
return self._impl._get_all_finetuned_models()

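A short sketch of the new read/write stream property on TrainableModule (the model name is illustrative):

import lazyllm

m = lazyllm.TrainableModule('internlm2-chat-7b')
assert m.stream is False  # constructor default is now non-streaming
m.stream = True           # can be flipped after construction, e.g. by SharedLLM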
8 changes: 8 additions & 0 deletions lazyllm/module/onlineChatModule/onlineChatModuleBase.py
@@ -48,6 +48,14 @@ def series(self):
def type(self):
return "LLM"

@property
def stream(self):
return self._stream

@stream.setter
def stream(self, v: bool):
self._stream = v

def prompt(self, prompt=None):
if prompt is None:
self._prompt = ChatPrompter()
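Online chat modules expose the same property, so SharedLLM's stream override works uniformly across local and online backends. A sketch, assuming OnlineChatModule accepts a source argument as the tests below do:

import lazyllm

chat = lazyllm.OnlineChatModule(source='glm')  # assumed constructor, mirroring the tests
chat.stream = True   # enable token streaming for this module only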
35 changes: 35 additions & 0 deletions tests/advanced_tests/standard_test/test_engine.py
@@ -92,3 +92,38 @@ def test_multimedia(self):

r = engine.run(gid, '生成音乐,长笛独奏,大自然之声。')
assert '.wav' in r

def test_stream_and_history(self):
resources = [dict(id='0', kind='LocalLLM', name='base', args=dict(base_model='internlm2-chat-7b'))]
nodes = [dict(id='1', kind='SharedLLM', name='m1', args=dict(llm='0', stream=True, prompt=dict(
system='请将我的问题翻译成中文。请注意,请直接输出翻译后的问题,不要反问和发挥',
user='问题: {query} \n, 翻译:'))),
dict(id='2', kind='SharedLLM', name='m2',
args=dict(llm='0', stream=True,
prompt=dict(system='请参考历史对话,回答问题,并保持格式不变。', user='{query}'))),
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['query', 'answer'])),
dict(id='4', kind='SharedLLM', name='m3',
args=dict(llm='0', stream=False, prompt=dict(system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
engine = LightEngine()
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '2'], ['1', '3'], ['2', '3'], ['3', '4'],
['4', '__end__']], resources=resources, _history_ids=['2', '4'])
history = [['水的沸点是多少?', '您好,我的答案是:水的沸点在标准大气压下是100摄氏度。'],
['世界上最大的动物是什么?', '您好,我的答案是:蓝鲸是世界上最大的动物。'],
['人一天需要喝多少水?', '您好,我的答案是:一般建议每天喝8杯水,大约2升。'],
['雨后为什么会有彩虹?', '您好,我的答案是:雨后阳光通过水滴发生折射和反射形成了彩虹。'],
['月亮会发光吗?', '您好,我的答案是:月亮本身不会发光,它反射太阳光。'],
['一年有多少天', '您好,我的答案是:一年有365天,闰年有366天。']]

stream_result = ''
with lazyllm.ThreadPoolExecutor(1) as executor:
future = executor.submit(engine.run, gid, 'How many hours are there in a day?', _lazyllm_history=history)
while True:
if value := lazyllm.FileSystemQueue().dequeue():
stream_result += f"{''.join(value)}"
elif future.done():
break
result = future.result()
assert '一天' in stream_result and '小时' in stream_result
assert '您好,我的答案是' in stream_result and '24' in stream_result
assert '蓝鲸' in result and '水' in result
34 changes: 34 additions & 0 deletions tests/charge_tests/test_engine.py
@@ -222,3 +222,37 @@ def test_register_tools(self):
unit = 'Celsius'
ret = engine.run(gid, f"What is the temperature in {city_name} today in {unit}?")
assert city_name in ret and unit in ret and '10' in ret

def test_stream_and_history(self):
nodes = [dict(id='1', kind='OnlineLLM', name='m1', args=dict(source='glm', stream=True, prompt=dict(
system='请将我的问题翻译成中文。请注意,请直接输出翻译后的问题,不要反问和发挥',
user='问题: {query} \n, 翻译:'))),
dict(id='2', kind='OnlineLLM', name='m2',
args=dict(source='glm', stream=True,
prompt=dict(system='请参考历史对话,回答问题,并保持格式不变。', user='{query}'))),
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['query', 'answer'])),
dict(id='4', kind='OnlineLLM', name='m3', args=dict(source='glm', stream=False, prompt=dict(
system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
engine = LightEngine()
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '2'], ['1', '3'], ['2', '3'],
['3', '4'], ['4', '__end__']], _history_ids=['2', '4'])
history = [['水的沸点是多少?', '您好,我的答案是:水的沸点在标准大气压下是100摄氏度。'],
['世界上最大的动物是什么?', '您好,我的答案是:蓝鲸是世界上最大的动物。'],
['人一天需要喝多少水?', '您好,我的答案是:一般建议每天喝8杯水,大约2升。'],
['雨后为什么会有彩虹?', '您好,我的答案是:雨后阳光通过水滴发生折射和反射形成了彩虹。'],
['月亮会发光吗?', '您好,我的答案是:月亮本身不会发光,它反射太阳光。'],
['一年有多少天', '您好,我的答案是:一年有365天,闰年有366天。']]

stream_result = ''
with lazyllm.ThreadPoolExecutor(1) as executor:
future = executor.submit(engine.run, gid, 'How many hours are there in a day?', _lazyllm_history=history)
while True:
if value := lazyllm.FileSystemQueue().dequeue():
stream_result += f"{''.join(value)}"
elif future.done():
break
result = future.result()
assert '一天' in stream_result and '小时' in stream_result
assert '您好,我的答案是' in stream_result and '24' in stream_result
assert '蓝鲸' in result and '水' in result
