
add stream & history test for engine
wzh1994 committed Nov 18, 2024
1 parent f07a9e1 commit 1f6bb6b
Showing 6 changed files with 72 additions and 6 deletions.
3 changes: 2 additions & 1 deletion lazyllm/components/prompter/prompter.py
@@ -1,6 +1,7 @@
import re
import json
import collections
from lazyllm import LOG

templates = dict(
# Template used by Alpaca-LoRA.
@@ -75,7 +76,7 @@ def generate_prompt(self, input, history=None, tools=None, label=None, show=Fals
raise RuntimeError(f'Generate prompt failed, and prompt is {self._prompt}; chat-prompt'
f' is {self._chat_prompt}; input is {input}; history is {history}')
if label: input += label
if self._show or show: print(input)
if self._show or show: LOG.info(input)
return input

def get_response(self, response, input=None):
15 changes: 13 additions & 2 deletions lazyllm/engine/engine.py
@@ -430,17 +430,28 @@ def share(self, prompt: str):
def forward(self, *args, **kw):
return self.vqa(*args, **kw)

@property
def stream(self):
return self._vqa._stream

@stream.setter
def stream(self, v: bool):
self._vqa._stream = v


@NodeConstructor.register('VQA')
def make_vqa(base_model: str, file_resource_id: Optional[str] = None):
return VQA(base_model, file_resource_id)


@NodeConstructor.register('SharedLLM')
def make_shared_llm(llm: str, prompt: Optional[str] = None, file_resource_id: Optional[str] = None):
def make_shared_llm(llm: str, prompt: Optional[str] = None, stream: Optional[bool] = None,
file_resource_id: Optional[str] = None):
llm = Engine().build_node(llm).func
if file_resource_id: assert isinstance(llm, VQA), 'file_resource_id is only supported in VQA'
return VQA(llm._vqa.share(prompt=prompt), file_resource_id) if file_resource_id else llm.share(prompt=prompt)
r = VQA(llm._vqa.share(prompt=prompt), file_resource_id) if file_resource_id else llm.share(prompt=prompt)
if stream is not None: r.stream = stream
return r


@NodeConstructor.register('STT')
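In engine.py above, VQA gains a stream property that simply forwards to the module it wraps, and make_shared_llm accepts an optional stream flag so a SharedLLM node can override streaming on whatever module it shares. A minimal sketch of that forwarding pattern, using made-up stand-in classes rather than the real lazyllm ones:

class Inner:
    # stands in for the wrapped VQA/LLM module that actually owns the flag
    def __init__(self, stream: bool = False):
        self._stream = stream

class Outer:
    # stands in for VQA: exposes stream but stores nothing itself
    def __init__(self, inner: Inner):
        self._vqa = inner

    @property
    def stream(self) -> bool:
        return self._vqa._stream

    @stream.setter
    def stream(self, v: bool):
        self._vqa._stream = v

node = Outer(Inner())
node.stream = True   # mirrors `if stream is not None: r.stream = stream` above
assert node._vqa._stream is True
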
9 changes: 9 additions & 0 deletions lazyllm/module/module.py
@@ -707,6 +707,7 @@ def __init__(self, base_model: Option = '', target_path='', *, stream=False, ret
None, lazyllm.finetune.auto, lazyllm.deploy.auto)
self._impl._add_father(self)
self.prompt()
self._stream = stream

base_model = property(lambda self: self._impl._base_model)
target_path = property(lambda self: self._impl._target_path)
@@ -720,6 +721,14 @@ def series(self):
def type(self):
return ModelManager.get_model_type(self.base_model).upper()

@property
def stream(self):
return self._stream

@stream.setter
def stream(self, v: bool):
self._stream = v

def get_all_finetuned_models(self):
return self._impl._get_all_finetuned_models()

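module.py now stores the constructor's stream flag and exposes it through a matching property, so the engine (for example make_shared_llm above) can flip streaming after a module has been built. A hedged usage sketch; the model name is illustrative only, and constructing the module may require that model to be available locally:

import lazyllm

m = lazyllm.TrainableModule('internlm2-chat-7b', stream=False)
assert m.stream is False
m.stream = True   # the setter added above only flips the stored flag
assert m.stream is True
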
8 changes: 8 additions & 0 deletions lazyllm/module/onlineChatModule/onlineChatModuleBase.py
@@ -48,6 +48,14 @@ def series(self):
def type(self):
return "LLM"

@property
def stream(self):
return self._stream

@stream.setter
def stream(self, v: bool):
self._stream = v

def prompt(self, prompt=None):
if prompt is None:
self._prompt = ChatPrompter()
37 changes: 37 additions & 0 deletions tests/advanced_tests/standard_test/test_engine.py
@@ -92,3 +92,40 @@ def test_multimedia(self):

r = engine.run(gid, '生成音乐,长笛独奏,大自然之声。')
assert '.wav' in r

def test_stream_and_history(self):
resources = [dict(id='0', kind='LocalLLM', name='base', args=dict(base_model='internlm2-chat-7b'))]
nodes = [dict(id='1', kind='SharedLLM', name='m1', args=dict(llm='0', stream=True, prompt=dict(
system='请将我的问题翻译成中文。请注意,请直接输出翻译后的问题,不要反问和发挥',
user='问题: {query} \n, 翻译:'))),
dict(id='2', kind='SharedLLM', name='m2',
args=dict(llm='0', stream=True,
prompt=dict(system='请参考历史对话,回答问题,并保持格式不变。', user='{query}'))),
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['query', 'answer'])),
dict(id='4', kind='SharedLLM', stream=False, name='m3',
args=dict(llm='0', prompt=dict(system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
engine = LightEngine()
# TODO handle duplicated node id
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '2'], ['1', '3'],
['2', '3'], ['3', '4'], ['4', '__end__']], resources=resources)
history = [['水的沸点是多少?', '您好,我的答案是:水的沸点在标准大气压下是100摄氏度。'],
['世界上最大的动物是什么?', '您好,我的答案是:蓝鲸是世界上最大的动物。'],
['人一天需要喝多少水?', '您好,我的答案是:一般建议每天喝8杯水,大约2升。'],
['雨后为什么会有彩虹?', '您好,我的答案是:雨后阳光通过水滴发生折射和反射形成了彩虹。'],
['月亮会发光吗?', '您好,我的答案是:月亮本身不会发光,它反射太阳光。'],
['一年有多少天', '您好,我的答案是:一年有365天,闰年有366天。']]

stream_result = ''
with lazyllm.ThreadPoolExecutor(1) as executor:
future = executor.submit(engine.run, gid, 'How many hours are there in a day?',
_llm_history=dict(ids=['2', '4'], history=history))
while True:
if value := lazyllm.FileSystemQueue().dequeue():
stream_result += f"{''.join(value)}"
elif future.done():
break
result = future.result()
assert '一天' in stream_result and '小时' in stream_result
assert '您好,我的答案是' in stream_result and '24' in stream_result
assert '蓝鲸' in result and '水' in result
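
The new test drives the graph from a worker thread and drains streamed chunks out of lazyllm's FileSystemQueue while waiting for the final answer. The same polling loop, distilled into a small helper (a sketch assembled from the test above; engine and gid are assumed to come from LightEngine().start(...)):

import lazyllm

def run_with_stream(engine, gid, query, **kw):
    # Run the graph in a background thread and collect streamed chunks
    # from the FileSystemQueue until the final result is ready.
    chunks = ''
    with lazyllm.ThreadPoolExecutor(1) as executor:
        future = executor.submit(engine.run, gid, query, **kw)
        while True:
            if value := lazyllm.FileSystemQueue().dequeue():
                chunks += ''.join(value)
            elif future.done():
                break
    return chunks, future.result()

Both the streamed text and the final answer can then be asserted on separately, which is exactly what the test does.
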
6 changes: 3 additions & 3 deletions tests/charge_tests/test_engine.py
@@ -231,9 +231,9 @@ def test_stream_and_history(self):
args=dict(source='glm', stream=True,
prompt=dict(system='请参考历史对话,回答问题,并保持格式不变。', user='{query}'))),
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['query', 'answer'])),
dict(id='4', kind='OnlineLLM', stream=False, name='m3',
args=dict(source='glm', prompt=dict(system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合本轮的问题,总结我们的全部对话。问题: {query}, 回答: {answer}')))]
dict(id='4', kind='OnlineLLM', stream=False, name='m3', args=dict(source='glm', prompt=dict(
system='你是一个问答机器人,会根据用户的问题作出回答。',
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
engine = LightEngine()
# TODO handle duplicated node id
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '2'], ['1', '3'],
