Skip to content

Commit

Permalink
re impl
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Oct 25, 2024
1 parent 3d8f2f5 commit 8e3e953
Showing 1 changed file with 88 additions and 74 deletions.
162 changes: 88 additions & 74 deletions lazyllm/tools/agent/toolsManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,54 +41,12 @@ def _gen_func_from_str(func_str, orig_docstring, global_env=None):
f.__doc__ = orig_docstring
return f

def _gen_wrapped_moduletool(func):
if "tmp_tool" not in LazyLLMRegisterMetaClass.all_clses:
register.new_group('tmp_tool')
register('tmp_tool')(func)
wrapped_module = getattr(lazyllm.tmp_tool, func.__name__)()
lazyllm.tmp_tool.remove(func.__name__)
return wrapped_module

def _gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring):
"""
returns a dict of param names containing at least
1. `type`
2. `description` of params
for example:
args = {
"foo": {
"enum": ["baz", "bar"],
"type": "string",
"description": "a string",
},
"bar": {
"type": "integer",
"description": "an integer",
}
}
"""
tool_args = tool.args
assert len(tool_args) == len(parsed_docstring.params), ("The parameter description and the actual "
"number of input parameters are inconsistent.")

args_description = {}
for param in parsed_docstring.params:
args_description[param.arg_name] = param.description

args = {}
for k, v in tool_args.items():
val = copy.deepcopy(v)
val.pop("title", None)
val.pop("default", None)
args[k] = val if val else {"type": "string"}
desc = args_description.get(k, None)
if desc:
args[k].update({"description": desc})
else:
raise ValueError(f"The actual input parameter {k} is not found "
"in the parameter description.")
return args
def _check_return_info_is_the_same(doc_type_hints, func_type_hints) -> None:
func_return_type = func_type_hints.get('return') if func_type_hints else None
doc_return_type = doc_type_hints.get('return') if doc_type_hints else None
if func_return_type is not None and doc_return_type is not None:
if func_return_type != doc_return_type:
raise TypeError("return info in docstring is different from that in function prototype.")

# ---------------------------------------------------------------------------- #

Expand All @@ -103,19 +61,37 @@ def __init__(self, verbose: bool = False, return_trace: bool = True):
if hasattr(self.apply, "__doc__") and self.apply.__doc__ is not None\
else (_ for _ in ()).throw(ValueError("Function must have a docstring"))

self._params_schema = self._load_function_schema(self.__class__.apply)

def _load_function_schema(self, func: Callable) -> Type[BaseModel]:
parsed_docstring = docstring_parser.parse(self._description)
func_str_from_doc = _gen_empty_func_str_from_parsed_docstring(parsed_docstring)
func_from_doc = _gen_func_from_str(func_str_from_doc, self._description)
func_from_doc.__name__ = self._name
self._params_schema = self._load_function_schema(self.__class__.apply, func_from_doc)

def _load_function_schema(self, orig_func: Callable, func_from_doc: Callable) -> Type[BaseModel]:
orig_type_hints = get_type_hints(orig_func, globals(), locals())
doc_type_hints = get_type_hints(func_from_doc, globals(), locals())
self._type_hints = doc_type_hints | orig_type_hints

func_type_hints = get_type_hints(func, globals(), locals())

_check_return_info_is_the_same(doc_type_hints, func_type_hints)

signature = inspect.signature(func)
has_var_args = False
for name, param in signature.parameters.items():
if param.kind == inspect.Parameter.VAR_POSITIONAL or\
param.kind == inspect.Parameter.VAR_KEYWORD:
has_var_args = True
break

if has_var_args:
# we cannot get type hints from var args, so we get them from docstring
self._type_hints = doc_type_hints
else:
self._type_hints = func_type_hints
# accomplish type_hints from docstring
for name, type in doc_type_hints.items():
self._type_hints.setdefault(name, type)

self._return_type = self._type_hints.get('return') if self._type_hints else None

signature = inspect.signature(func_from_doc)
fields = {
name: (self._type_hints.get(name, Any), param.default
if param.default is not inspect.Parameter.empty
Expand Down Expand Up @@ -161,16 +137,16 @@ def _validate_input(self, tool_input: Dict[str, Any]) -> Dict[str, Any]:
elif isinstance(tool_input, str):
if input_params is not None:
key = next(iter(input_params.model_fields.keys()))
key_type = self._type_hints.get(key)
if not key_type or not isinstance(key_type, type):
raise TypeError(f'"{key_type} is not a type"')
input_params.model_validate({key: tool_input})
return {key: key_type(tool_input)}
arg_type = self._type_hints.get(key)
if arg_type:
return {key: arg_type(tool_input)}
return {key: tool_input}

types = self._type_hints.values()
if len(types) == 0:
if len(self._type_hints) != 1:
return tool_input
return types[0](tool_input)
arg_type = self._type_hints.values()[0]
return arg_type(tool_input)
else:
raise TypeError(f"tool_input {tool_input} only supports dict and str.")

Expand Down Expand Up @@ -247,10 +223,55 @@ def _format_tools(self):
self._tool_call = {tool.name: tool for tool in self._tools}

@staticmethod
def _check_return_info_is_the_same(func, tool) -> bool:
type_hints = get_type_hints(func, globals(), locals())
return_type = type_hints.get('return') if type_hints else None
return return_type == tool._return_type
def _gen_wrapped_moduletool(func):
if "tmp_tool" not in LazyLLMRegisterMetaClass.all_clses:
register.new_group('tmp_tool')
register('tmp_tool')(func)
wrapped_module = getattr(lazyllm.tmp_tool, func.__name__)()
lazyllm.tmp_tool.remove(func.__name__)
return wrapped_module

@staticmethod
def _gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring):
"""
returns a dict of param names containing at least
1. `type`
2. `description` of params
for example:
args = {
"foo": {
"enum": ["baz", "bar"],
"type": "string",
"description": "a string",
},
"bar": {
"type": "integer",
"description": "an integer",
}
}
"""
tool_args = tool.args
assert len(tool_args) == len(parsed_docstring.params), ("The parameter description and the actual "
"number of input parameters are inconsistent.")

args_description = {}
for param in parsed_docstring.params:
args_description[param.arg_name] = param.description

args = {}
for k, v in tool_args.items():
val = copy.deepcopy(v)
val.pop("title", None)
val.pop("default", None)
args[k] = val if val else {"type": "string"}
desc = args_description.get(k, None)
if desc:
args[k].update({"description": desc})
else:
raise ValueError(f"The actual input parameter {k} is not found "
"in the parameter description.")
return args

def _transform_to_openai_function(self):
if not isinstance(self._tools, List):
Expand All @@ -260,15 +281,8 @@ def _transform_to_openai_function(self):
for tool in self._tools:
try:
parsed_docstring = docstring_parser.parse(tool.description)
args = _gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring)
args = self._gen_args_info_from_moduletool_and_docstring(tool, parsed_docstring)
required_arg_list = tool.get_params_schema().model_json_schema().get("required", [])

# check return values
func_str_from_doc = _gen_empty_func_str_from_parsed_docstring(parsed_docstring)
func_from_doc = _gen_func_from_str(func_str_from_doc, tool.description)
if not self._check_return_info_is_the_same(func_from_doc, tool):
raise ValueError("return info in docstring is different from that in function prototype.")

func = {
"type": "function",
"function": {
Expand Down

0 comments on commit 8e3e953

Please sign in to comment.