diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index ad0b9fecc3ef..63a2c3889ba8 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -17,7 +17,7 @@ import json import logging import re -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .. import is_torch_available from ..utils import logging as transformers_logging @@ -256,15 +256,6 @@ def __repr__(self): return toolbox_description -def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: - tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) - prompt = prompt_template.replace("<>", tool_descriptions) - if "<>" in prompt: - tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()] - prompt = prompt.replace("<>", ", ".join(tool_names)) - return prompt - - class AgentError(Exception): """Base class for other agent-related exceptions""" @@ -297,6 +288,21 @@ class AgentGenerationError(AgentError): pass +def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: + tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) + prompt = prompt_template.replace("<>", tool_descriptions) + if "<>" in prompt: + tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()] + prompt = prompt.replace("<>", ", ".join(tool_names)) + return prompt + + +def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str: + if "<>" not in prompt_template: + raise AgentError("Tag '<>' should be provided in the prompt.") + return prompt_template.replace("<>", str(authorized_imports)) + + class Agent: def __init__( self, @@ -359,8 +365,14 @@ def initialize_for_run(self, task: str, **kwargs): self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." self.state = kwargs.copy() self.system_prompt = format_prompt_with_tools( - self._toolbox, self.system_prompt_template, self.tool_description_template + self._toolbox, + self.system_prompt_template, + self.tool_description_template, ) + if hasattr(self, "authorized_imports"): + self.system_prompt = format_prompt_with_imports( + self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) + ) self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] self.logger.warn("======== New task ========") self.logger.log(33, self.task) @@ -496,7 +508,7 @@ def __init__( llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - additional_authorized_imports: List[str] = [], + additional_authorized_imports: Optional[List[str]] = None, **kwargs, ): super().__init__( @@ -515,7 +527,9 @@ def __init__( ) self.python_evaluator = evaluate_python_code - self.additional_authorized_imports = additional_authorized_imports + self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] + self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) + self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) def parse_code_blob(self, result: str) -> str: """ @@ -562,7 +576,13 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): return llm_output # Parse - _, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + try: + _, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + except Exception as e: + self.logger.debug( + f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" + ) + code_action = llm_output try: code_action = self.parse_code_blob(code_action) @@ -579,7 +599,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs): code_action, available_tools, state=self.state, - authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, + authorized_imports=self.authorized_imports, ) self.logger.info(self.state["print_outputs"]) return output @@ -639,17 +659,12 @@ def provide_final_answer(self, task) -> str: def run(self, task: str, stream: bool = False, **kwargs): """ Runs the agent for the given task. - Args: task (`str`): The task to perform - Example: - ```py - from transformers.agents import ReactJsonAgent, PythonInterpreterTool - - python_interpreter = PythonInterpreterTool() - agent = ReactJsonAgent(tools=[python_interpreter]) + from transformers.agents import ReactCodeAgent + agent = ReactCodeAgent(tools=[]) agent.run("What is the result of 2 power 3.7384?") ``` """ @@ -820,7 +835,7 @@ def __init__( llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - additional_authorized_imports: List[str] = [], + additional_authorized_imports: Optional[List[str]] = None, **kwargs, ): super().__init__( @@ -839,7 +854,9 @@ def __init__( ) self.python_evaluator = evaluate_python_code - self.additional_authorized_imports = additional_authorized_imports + self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] + self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) + self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) def step(self): """ @@ -871,7 +888,11 @@ def step(self): # Parse self.logger.debug("===== Extracting action =====") - rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + try: + rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + except Exception as e: + self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") + rationale, raw_code_action = llm_output, llm_output try: code_action = parse_code_blob(raw_code_action) @@ -890,7 +911,7 @@ def step(self): code_action, available_tools, state=self.state, - authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, + authorized_imports=self.authorized_imports, ) information = self.state["print_outputs"] self.logger.warning("Print outputs:") diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 7187422dc063..9adf55289d0e 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -125,12 +125,13 @@ def setup_default_tools(logger): for task_name, tool_class_name in TASK_MAPPING.items(): tool_class = getattr(tools_module, tool_class_name) + tool_instance = tool_class() default_tools[tool_class.name] = PreTool( - name=tool_class.name, - inputs=tool_class.inputs, - output_type=tool_class.output_type, + name=tool_instance.name, + inputs=tool_instance.inputs, + output_type=tool_instance.output_type, task=task_name, - description=tool_class.description, + description=tool_instance.description, repo_id=None, ) @@ -141,18 +142,25 @@ class PythonInterpreterTool(Tool): name = "python_interpreter" description = "This is a tool that evaluates python code. It can be used to perform calculations." - inputs = { - "code": { - "type": "text", - "description": ( - "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " - f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}." - ), - } - } output_type = "text" available_tools = BASE_PYTHON_TOOLS.copy() + def __init__(self, *args, authorized_imports=None, **kwargs): + if authorized_imports is None: + authorized_imports = list(set(LIST_SAFE_MODULES)) + else: + authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) + self.inputs = { + "code": { + "type": "text", + "description": ( + "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " + f"else you will get an error. This code can only import the following python libraries: {authorized_imports}." + ), + } + } + super().__init__(*args, **kwargs) + def forward(self, code): output = str(evaluate_python_code(code, tools=self.available_tools)) return output diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 4e5ff9970811..661df9bd24e7 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -52,6 +52,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns. You should first explain which tool you will use to perform the task and for what reason, then write the code in Python. Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so. +You can use imports in your code, but only from the following list of modules: <> Be sure to provide a 'Code:' token, else the system will be stuck in a loop. Tools: @@ -263,7 +264,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can. -To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code. +To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use. @@ -356,6 +357,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. +7. You can use imports in your code, but only from the following list of modules: <> Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 8dc535e63c71..79e55bf6523c 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -141,15 +141,21 @@ def test_react_fails_max_iterations(self): def test_init_agent_with_different_toolsets(self): toolset_1 = [] agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) - assert len(agent.toolbox.tools) == 1 # contains only final_answer tool + assert ( + len(agent.toolbox.tools) == 1 + ) # when no tools are provided, only the final_answer tool is added by default toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) - assert len(agent.toolbox.tools) == 2 # added final_answer tool + assert ( + len(agent.toolbox.tools) == 2 + ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer toolset_3 = Toolbox(toolset_2) agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) - assert len(agent.toolbox.tools) == 2 # added final_answer tool + assert ( + len(agent.toolbox.tools) == 2 + ) # same as previous one, where toolset_3 is an instantiation of previous one # check that add_base_tools will not interfere with existing tools with pytest.raises(KeyError) as e: diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index dbe6c90a9ea0..51775e31e761 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -32,9 +32,19 @@ def add_two(x): class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("python_interpreter") + self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"]) self.tool.setup() + def test_exact_match_input_spec(self): + inputs_spec = self.tool.inputs + expected_description = ( + "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " + "else you will get an error. This code can only import the following python libraries: " + "['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', " + "'random', 're']." + ) + self.assertEqual(inputs_spec["code"]["description"], expected_description) + def test_exact_match_arg(self): result = self.tool("(2 / 2) * 4") self.assertEqual(result, "4.0")