Skip to content

Commit

Permalink
refactor and fix token usage key missing
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyongchao committed Nov 13, 2024
1 parent 41d0bd5 commit d5c3dc7
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 18 deletions.
16 changes: 2 additions & 14 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def _impl(nid, recursive):
yield id
if recursive:
yield from self.subnodes(id, True)

return list(_impl(nodeid, recursive))


Expand All @@ -105,7 +104,6 @@ def impl(f):
for name in names:
cls.builder_methods[name] = (f, subitems)
return f

return impl

# build node recursively
Expand Down Expand Up @@ -171,11 +169,7 @@ def get_args(cls, key, value, builder_key=None):
def _process_hook(self, node, module):
if not node.enable_data_reflow:
return
if isinstance(module, lazyllm.ModuleBase):
NodeMetaHook.MODULEID_TO_WIDGETID[module._module_id] = node.id
elif isinstance(module, lazyllm.LazyLLMFlowsBase):
NodeMetaHook.MODULEID_TO_WIDGETID[module._flow_id] = node.id
else:
if not isinstance(module, (lazyllm.ModuleBase, lazyllm.LazyLLMFlowsBase)):
return
node.func.register_hook(NodeMetaHook(node.func, Engine.REPORT_URL, node.id))

Expand Down Expand Up @@ -487,27 +481,22 @@ def _parse_py_data_by_formatter(self, data):
else:
raise TypeError('type should be one of sum/stack/to_dict/join')


@NodeConstructor.register('JoinFormatter')
def make_join_formatter(type='sum', names=None, symbol=None):
if type == 'file':
return make_formatter('file', rule='merge')
return JoinFormatter(type, names=names, symbol=symbol)


@NodeConstructor.register('Formatter')
def make_formatter(ftype, rule):
return getattr(lazyllm.formatter, ftype)(formatter=rule)


def return_a_wrapper_func(func):
@functools.wraps(func)
def wrapper_func(*args, **kwargs):
return func(*args, **kwargs)

return wrapper_func


def _get_tools(tools):
callable_list = []
for rid in tools: # `tools` is a list of ids in engine's resources
Expand All @@ -520,12 +509,10 @@ def _get_tools(tools):
callable_list.append(wrapper_func)
return callable_list


@NodeConstructor.register('ToolsForLLM')
def make_tools_for_llm(tools: List[str]):
return lazyllm.tools.ToolManager(_get_tools(tools))


@NodeConstructor.register('FunctionCall')
def make_fc(llm: str, tools: List[str], algorithm: Optional[str] = None):
f = (
Expand Down Expand Up @@ -566,6 +553,7 @@ def make_http_tool(


class VQA(lazyllm.Module):

def __init__(
self,
base_model: Union[str, lazyllm.TrainableModule],
Expand Down
5 changes: 1 addition & 4 deletions lazyllm/engine/node_meta_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class MetaKeys:


class NodeMetaHook(LazyLLMHook):
URL = ""
MODULEID_TO_WIDGETID = {}

def __init__(self, obj, url, front_id):
if isinstance(obj, lazyllm.ModuleBase):
Expand Down Expand Up @@ -56,15 +54,14 @@ def post_hook(self, output):
self._meta_info[MetaKeys.OUTPUT] = str(output)

if self._uniqueid in globals["usage"]:
self._meta_info.update(globals["usage"])
self._meta_info.update(globals["usage"][self._uniqueid])
self._meta_info[MetaKeys.ID] = self._front_id
self._meta_info[MetaKeys.TIMECOST] = time.time() - self._meta_info[MetaKeys.TIMECOST]

def report(self):
headers = {"Content-Type": "application/json; charset=utf-8"}
json_data = json.dumps(self._meta_info, ensure_ascii=False)
try:
lazyllm.LOG.info(f"meta_info: {self._meta_info}")
requests.post(self._url, data=json_data, headers=headers)
except Exception as e:
lazyllm.LOG.warning(f"Error sending collected data: {e}. URL: {self._url}")
2 changes: 2 additions & 0 deletions tests/basic_tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
@app.post("/{route}")
async def receive_json(data: dict):
print("Received json data:", data)
assert "prompt_tokens" in data
return JSONResponse(content=data)
Expand Down Expand Up @@ -62,6 +63,7 @@ def read_stdout(process):

@classmethod
def tearDownClass(cls):
time.sleep(3)
cls.fastapi_process.terminate()
cls.fastapi_process.wait()

Expand Down

0 comments on commit d5c3dc7

Please sign in to comment.