Skip to content

Commit

Permalink
fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 committed Sep 9, 2024
1 parent 71a6fbf commit 24c3871
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
22 changes: 20 additions & 2 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(self, g: lazyllm.graph, server: Node, web: Node):
self._g = lazyllm.ActionModule(g)
if server:
if server.args.get('port'): raise NotImplementedError('Port is not supported now')
self._server = lazyllm.ServerModule(g)
self._g = lazyllm.ServerModule(g)
if web:
port = self._get_port(web.args['port'])
self._web = lazyllm.WebModule(g, port=port, title=web.args['title'], audio=web.args['audio'],
Expand All @@ -115,8 +115,14 @@ def __init__(self, g: lazyllm.graph, server: Node, web: Node):
def forward(self, *args, **kw):
return self._g(*args, **kw)

# TODO(wangzhihong)
def _update(self, *, mode=None, recursive=True):
super(__class__, self)._update(mode=mode, recursive=recursive)
if hasattr(self, '_web'): self._web.start()
return self

def _get_port(self, port):
if port == '': return None
if not port: return None
elif ',' in port:
return list(int(p.strip()) for p in port.split(','))
elif '-' in port:
Expand All @@ -125,6 +131,18 @@ def _get_port(self, port):
return range(left, right)
return int(port)

@property
def api_url(self):
if isinstance(self._g, lazyllm.ServerModule):
return self._g._url
return None

@property
def web_url(self):
if hasattr(self, '_web'):
return self._web.url
return None


@NodeConstructor.register('Graph')
@NodeConstructor.register('SubGraph')
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/engine/lightengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def start(self, nodes: List[Dict] = [], edges: List[Dict] = [], resources: List[
node = Node(id=gid or str(uuid.uuid4().hex), kind='Graph',
name=name or str(uuid.uuid4().hex), args=dict(nodes=nodes, edges=edges, resources=resources))
self.graph = self.build_node(node).func
ActionModule(self.graph).start()
self.graph.start()

def update(self, nodes: List[Dict] = [], changed_nodes: List[Dict] = [],
edges: List[Dict] = [], changed_resources: List[Dict] = [],
Expand Down
22 changes: 22 additions & 0 deletions tests/basic_tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from lazyllm import LightEngine
import pytest
import time
# from gradio_client import Client
import lazyllm

class TestEngine(object):

Expand Down Expand Up @@ -98,6 +101,25 @@ def test_engine_formatter(self):
assert engine.run(1) == '1[2, 4]1'
assert engine.run(2) == '2[4, 8]4'

def test_engine_server(self):
nodes = [dict(id='1', kind='Code', name='m1', args='def test(x: int):\n return 2 * x\n')]
edges = [dict(iid='__start__', oid='1'), dict(iid='1', oid='__end__')]
resources = [dict(id='2', kind='server', name='s1', args=dict(port=None)),
# dict(id='3', kind='web', name='w1', args=dict(port=None, title='网页', history=[], audio=False))
]
engine = LightEngine()
engine.start(nodes, edges, resources, gid='graph-1')
assert engine.run(1) == 2
time.sleep(3)
# web = engine.build_node('graph-1').func._web
# client = Client(web.url, download_files=web.cach_path)
# chat_history = [['123', None]]
# ans = client.predict(False, chat_history, False, False, api_name="/_respond_stream")
# assert ans[0][-1][-1] == '123123'
# client.close()
lazyllm.launcher.cleanup()
# web.stop()


class TestEngineRAG(object):

Expand Down

0 comments on commit 24c3871

Please sign in to comment.