Skip to content

Commit

Permalink
add sharedllm and switch tests (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 authored Sep 11, 2024
1 parent f68d0b2 commit 175c12f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 4 additions & 0 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,7 @@ def make_fc(llm: str, tools: List[str], algorithm: Optional[str] = None):
lazyllm.tools.ReWOOAgent if algorithm == 'ReWOO' else \
lazyllm.tools.ReactAgent if algorithm == 'React' else lazyllm.tools.FunctionCallAgent
return f(Engine().build_node(llm).func, tools)

@NodeConstructor.register('SharedLLM')
def make_shared_llm(llm: str, prompt: Optional[str] = None):
return Engine().build_node(llm).func.share(prompt=prompt)
20 changes: 18 additions & 2 deletions tests/basic_tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ def run_around_tests(self):
LightEngine().reset()

def test_engine_subgraph(self):
nodes = [dict(id='1', kind='LocalLLM', name='m1', args=dict(base_model='', deploy_method='dummy'))]
resources = [dict(id='0', kind='LocalLLM', name='m1', args=dict(base_model='', deploy_method='dummy'))]
nodes = [dict(id='1', kind='SharedLLM', name='s1', args=dict(llm='0', prompt=None))]
edges = [dict(iid='__start__', oid='1'), dict(iid='1', oid='__end__')]

nodes = [dict(id='2', kind='SubGraph', name='s1', args=dict(nodes=nodes, edges=edges))]
edges = [dict(iid='__start__', oid='2'), dict(iid='2', oid='__end__')]

engine = LightEngine()
engine.start(nodes, edges)
engine.start(nodes, edges, resources)
r = engine.run('1234')
assert 'reply for You are an AI-Agent developed by LazyLLM' in r
assert '1234' in r
Expand Down Expand Up @@ -50,6 +51,21 @@ def test_engine_switch(self):
assert engine.run(2) == 6
assert engine.run(3) == 9

engine.reset()

switch = dict(id='4', kind='Switch', name='s1', args=dict(judge_on_full_input=False, nodes={
'case1': [double],
'case2': [plus1, double],
'case3': [square]
}))
engine.start([switch], edges)
assert engine.run('case1', 1) == 2
assert engine.run('case2', 1) == 4
assert engine.run('case3', 1) == 1
assert engine.run('case1', 2) == 4
assert engine.run('case2', 2) == 6
assert engine.run('case3', 3) == 9

def test_engine_ifs(self):
plus1 = dict(id='1', kind='Code', name='m1', args='def test(x: int):\n return 1 + x\n')
double = dict(id='2', kind='Code', name='m2', args='def test(x: int):\n return 2 * x\n')
Expand Down

0 comments on commit 175c12f

Please sign in to comment.