Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: RewooAgent does not convert input types #323

Merged
merged 6 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lazyllm/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def graph(json_file):
engine_conf = json.load(fp)

engine = LightEngine()
engine.start(engine_conf.get('nodes', []), engine_conf.get('edges', []),
engine_conf.get('resources', []))
eid = engine.start(engine_conf.get('nodes', []), engine_conf.get('edges', []),
engine_conf.get('resources', []))
while True:
query = input("query(enter 'quit' to exit): ")
if query == 'quit':
break
res = engine.run(query)
res = engine.run(eid, query)
print(f'answer: {res}')

def run(commands):
Expand Down
16 changes: 8 additions & 8 deletions lazyllm/tools/agent/rewooAgent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lazyllm.module import ModuleBase
from lazyllm import pipeline, package, LOG, globals, bind
from .toolsManager import ToolManager
from typing import List, Dict, Union
from typing import List, Dict, Union, Callable
import re

P_PROMPT_PREFIX = ("For the following tasks, make plans that can solve the problem step-by-step. "
Expand Down Expand Up @@ -31,7 +31,7 @@
"the answer directly with no extra words.\n\n")

class ReWOOAgent(ModuleBase):
def __init__(self, llm: Union[ModuleBase, None] = None, tools: List[str] = [], *,
def __init__(self, llm: Union[ModuleBase, None] = None, tools: List[Union[str, Callable]] = [], *,
plan_llm: Union[ModuleBase, None] = None, solve_llm: Union[ModuleBase, None] = None,
return_trace: bool = False):
super().__init__(return_trace=return_trace)
Expand All @@ -41,8 +41,7 @@ def __init__(self, llm: Union[ModuleBase, None] = None, tools: List[str] = [], *
assert tools, "tools cannot be empty."
self._planner = plan_llm or llm
self._solver = solve_llm or llm
self._workers = tools
self._tools_manager = ToolManager(tools, return_trace=return_trace).tools_info
self._name2tool = ToolManager(tools, return_trace=return_trace).tools_info
with pipeline() as self._agent:
self._agent.planner_pre_action = self._build_planner_prompt
self._agent.planner = self._planner
Expand All @@ -53,8 +52,8 @@ def __init__(self, llm: Union[ModuleBase, None] = None, tools: List[str] = [], *

def _build_planner_prompt(self, input: str):
prompt = P_PROMPT_PREFIX + "Tools can be one of the following:\n"
for name in self._workers:
prompt += f"{name}[search query]: {self._tools_manager[name].description}\n"
for name, tool in self._name2tool.items():
prompt += f"{name}[search query]: {tool.description}\n"
prompt += P_FEWSHOT + "\n" + P_PROMPT_SUFFIX + input + "\n"
LOG.info(f"planner prompt: {prompt}")
globals['chat_history'][self._planner._module_id] = []
Expand Down Expand Up @@ -88,8 +87,9 @@ def _get_worker_evidences(self, plans: List[str], evidence: Dict[str, str]):
for var in re.findall(r"#E\d+", tool_input):
if var in worker_evidences:
tool_input = tool_input.replace(var, "[" + worker_evidences[var] + "]")
if tool in self._workers:
worker_evidences[e] = self._tools_manager[tool](tool_input)
tool_instance = self._name2tool.get(tool)
if tool_instance:
worker_evidences[e] = tool_instance(tool_input)
else:
worker_evidences[e] = "No evidence found"

Expand Down
175 changes: 85 additions & 90 deletions lazyllm/tools/agent/toolsManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,45 @@
from typing import * # noqa F403, to import all types for compile_func(), do not remove
import time

# ---------------------------------------------------------------------------- #

def _gen_empty_func_str_from_parsed_docstring(parsed_docstring):
"""
returns a function prototype string
"""

func_name = "f" + str(int(time.time()))
s = "def " + func_name + "("
for param in parsed_docstring.params:
s += param.arg_name
if param.type_name:
s += ":" + param.type_name + ","
else:
s += ","
s += ")"

if parsed_docstring.returns and parsed_docstring.returns.type_name:
s += '->' + parsed_docstring.returns.type_name
s += ":\n pass"

return s

def _gen_func_from_str(func_str, orig_docstring, global_env=None):
if not global_env:
global_env = globals()
f = compile_func(func_str, global_env)
f.__doc__ = orig_docstring
return f

def _check_return_type_is_the_same(doc_type_hints, func_type_hints) -> None:
func_return_type = func_type_hints.get('return') if func_type_hints else None
doc_return_type = doc_type_hints.get('return') if doc_type_hints else None
if func_return_type is not None and doc_return_type is not None:
if func_return_type != doc_return_type:
raise TypeError("return info in docstring is different from that in function prototype.")

# ---------------------------------------------------------------------------- #

class ModuleTool(ModuleBase, metaclass=LazyLLMRegisterMetaClass):
def __init__(self, verbose: bool = False, return_trace: bool = True):
super().__init__(return_trace=return_trace)
Expand All @@ -21,28 +60,46 @@ def __init__(self, verbose: bool = False, return_trace: bool = True):
self._description = self.apply.__doc__\
if hasattr(self.apply, "__doc__") and self.apply.__doc__ is not None\
else (_ for _ in ()).throw(ValueError("Function must have a docstring"))
# strip space(s) and newlines before and after docstring, as RewooAgent requires
self._description = self._description.strip(' \n')

self._params_schema = self.load_function_schema(self.__class__.apply)
self._params_schema = self._load_function_schema(self.__class__.apply)

def load_function_schema(self, func: Callable) -> Type[BaseModel]:
if func.__name__ is None or func.__doc__ is None:
raise ValueError(f"Function {func} must have a name and docstring.")
self._name = func.__name__
self._description = func.__doc__
signature = inspect.signature(func)
type_hints = get_type_hints(func, globals(), locals())
def _load_function_schema(self, func: Callable) -> Type[BaseModel]:
parsed_docstring = docstring_parser.parse(self._description)
func_str_from_doc = _gen_empty_func_str_from_parsed_docstring(parsed_docstring)
func_from_doc = _gen_func_from_str(func_str_from_doc, self._description)
func_from_doc.__name__ = func.__name__
doc_type_hints = get_type_hints(func_from_doc, globals(), locals())

func_type_hints = get_type_hints(func, globals(), locals())

_check_return_type_is_the_same(doc_type_hints, func_type_hints)

self._has_var_args = False
signature = inspect.signature(func)
has_var_args = False
for name, param in signature.parameters.items():
if param.kind == inspect.Parameter.VAR_POSITIONAL or\
param.kind == inspect.Parameter.VAR_KEYWORD:
self._has_var_args = True
has_var_args = True
break

self._return_type = type_hints.get('return') if type_hints else None
if has_var_args:
# we cannot get type hints from var args, so we get them from docstring
self._type_hints = doc_type_hints
signature = inspect.signature(func_from_doc)
else:
self._type_hints = func_type_hints
# accomplish type_hints from docstring
for name, type in doc_type_hints.items():
self._type_hints.setdefault(name, type)

self._return_type = self._type_hints.get('return') if self._type_hints else None

fields = {
name: (type_hints.get(name, Any), param.default if param.default is not inspect.Parameter.empty else ...)
name: (self._type_hints.get(name, Any), param.default
if param.default is not inspect.Parameter.empty
else ...)
for name, param in signature.parameters.items()
}

Expand All @@ -62,28 +119,16 @@ def params_schema(self) -> Type[BaseModel]:

@property
def args(self) -> Dict[str, Any]:
if self._params_schema is None:
self._params_schema = self.load_function_schema(getattr(type(self), "apply"))
return self._params_schema.model_json_schema()["properties"]

@property
def required_args(self) -> Set[str]:
if self._params_schema is None:
self._params_schema = self.load_function_schema(getattr(type(self), "apply"))
return set(self._params_schema.model_json_schema()["required"])

def get_params_schema(self) -> [BaseModel]:
if self._params_schema is None:
self._params_schema = self.load_function_schema(getattr(type(self), "apply"))
return self._params_schema

def apply(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError("Implement apply function in subclass")

def _validate_input(self, tool_input: Dict[str, Any]) -> Dict[str, Any]:
if self._has_var_args:
return tool_input

input_params = self._params_schema
if isinstance(tool_input, dict):
if input_params is not None:
Expand All @@ -94,8 +139,15 @@ def _validate_input(self, tool_input: Dict[str, Any]) -> Dict[str, Any]:
if input_params is not None:
key = next(iter(input_params.model_fields.keys()))
input_params.model_validate({key: tool_input})
arg_type = self._type_hints.get(key)
if arg_type:
return {key: arg_type(tool_input)}
return {key: tool_input}
return tool_input

if len(self._type_hints) != 1:
return tool_input
arg_type = self._type_hints.values()[0]
return arg_type(tool_input)
else:
raise TypeError(f"tool_input {tool_input} only supports dict and str.")

Expand All @@ -111,7 +163,10 @@ def validate_parameters(self, arguments: Dict[str, Any]) -> bool:

def forward(self, tool_input: Union[str, Dict[str, Any]], verbose: bool = False) -> Any:
val_input = self._validate_input(tool_input)
ret = self.apply(**val_input)
if isinstance(val_input, dict):
ret = self.apply(**val_input)
else:
ret = self.apply(val_input)
if verbose or self._verbose:
lazyllm.LOG.debug(f"The output of tool {self.name} is {ret}")

Expand Down Expand Up @@ -162,55 +217,12 @@ def _validate_tool(self, tool_name: str, tool_arguments: Dict[str, Any]):
LOG.error(f'cannot find tool named [{tool_name}]')
return False

# don't check parameters if this function contains '*args' or '**kwargs'
if tool._has_var_args:
return True

return tool.validate_parameters(tool_arguments)

def _format_tools(self):
if isinstance(self._tools, List):
self._tool_call = {tool.name: tool for tool in self._tools}

@staticmethod
def _gen_empty_func_str_from_parsed_docstring(parsed_docstring):
"""
returns a function prototype string
"""

func_name = "f" + str(int(time.time()))
s = "def " + func_name + "("
for param in parsed_docstring.params:
s += param.arg_name
if param.type_name:
s += ":" + param.type_name + ","
else:
s += ","
s += ")"

if parsed_docstring.returns and parsed_docstring.returns.type_name:
s += '->' + parsed_docstring.returns.type_name
s += ":\n pass"

return s

@staticmethod
def _gen_func_from_str(func_str, orig_docstring, global_env=None):
if not global_env:
global_env = globals()
f = compile_func(func_str, global_env)
f.__doc__ = orig_docstring
return f

@staticmethod
def _gen_wrapped_moduletool(func):
if "tmp_tool" not in LazyLLMRegisterMetaClass.all_clses:
register.new_group('tmp_tool')
register('tmp_tool')(func)
wrapped_module = getattr(lazyllm.tmp_tool, func.__name__)()
lazyllm.tmp_tool.remove(func.__name__)
return wrapped_module

@staticmethod
def _gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring):
"""
Expand Down Expand Up @@ -249,16 +261,10 @@ def _gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring):
if desc:
args[k].update({"description": desc})
else:
raise ValueError(f"The actual input parameter {k} is not found "
"in the parameter description.")
raise ValueError(f"The actual input parameter '{k}' is not found "
f"in the parameter description of tool '{tool.name}'.")
return args

@staticmethod
def _check_return_info_is_the_same(func, tool) -> bool:
type_hints = get_type_hints(func, globals(), locals())
return_type = type_hints.get('return') if type_hints else None
return return_type == tool._return_type

def _transform_to_openai_function(self):
if not isinstance(self._tools, List):
raise TypeError(f"The tools type should be List instead of {type(self._tools)}")
Expand All @@ -267,19 +273,8 @@ def _transform_to_openai_function(self):
for tool in self._tools:
try:
parsed_docstring = docstring_parser.parse(tool.description)
func_str_from_doc = self._gen_empty_func_str_from_parsed_docstring(parsed_docstring)
func_from_doc = self._gen_func_from_str(func_str_from_doc, tool.description)

if tool._has_var_args:
tmp_tool = self._gen_wrapped_moduletool(func_from_doc)
args = self._gen_args_info_from_moduletool_and_docstring(tmp_tool, parsed_docstring)
required_arg_list = tmp_tool.get_params_schema().model_json_schema().get("required", [])
else:
args = self._gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring)
required_arg_list = tool.get_params_schema().model_json_schema().get("required", [])
if not self._check_return_info_is_the_same(func_from_doc, tool):
raise ValueError("return info in docstring is different from that in function prototype.")

args = self._gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring)
required_arg_list = tool.params_schema.model_json_schema().get("required", [])
func = {
"type": "function",
"function": {
Expand Down
Loading
Loading