Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to declare imports for code agent #31355

Merged
73 changes: 47 additions & 26 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
import logging
import re
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .. import is_torch_available
from ..utils import logging as transformers_logging
Expand Down Expand Up @@ -256,15 +256,6 @@ def __repr__(self):
return toolbox_description


def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt


class AgentError(Exception):
"""Base class for other agent-related exceptions"""

Expand Down Expand Up @@ -297,6 +288,21 @@ class AgentGenerationError(AgentError):
pass


def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt


def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))


class Agent:
def __init__(
self,
Expand Down Expand Up @@ -359,8 +365,14 @@ def initialize_for_run(self, task: str, **kwargs):
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.system_prompt = format_prompt_with_tools(
self._toolbox, self.system_prompt_template, self.tool_description_template
self._toolbox,
self.system_prompt_template,
self.tool_description_template,
)
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
)
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.warn("======== New task ========")
self.logger.log(33, self.task)
Expand Down Expand Up @@ -496,7 +508,7 @@ def __init__(
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: List[str] = [],
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
Expand All @@ -515,7 +527,9 @@ def __init__(
)

self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))

def parse_code_blob(self, result: str) -> str:
"""
Expand Down Expand Up @@ -562,7 +576,13 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs):
return llm_output

# Parse
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
try:
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
code_action = llm_output

try:
code_action = self.parse_code_blob(code_action)
Expand All @@ -579,7 +599,7 @@ def run(self, task: str, return_generated_code: bool = False, **kwargs):
code_action,
available_tools,
state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
authorized_imports=self.authorized_imports,
)
self.logger.info(self.state["print_outputs"])
return output
Expand Down Expand Up @@ -639,17 +659,12 @@ def provide_final_answer(self, task) -> str:
def run(self, task: str, stream: bool = False, **kwargs):
"""
Runs the agent for the given task.

Args:
task (`str`): The task to perform

Example:

```py
from transformers.agents import ReactJsonAgent, PythonInterpreterTool

python_interpreter = PythonInterpreterTool()
agent = ReactJsonAgent(tools=[python_interpreter])
from transformers.agents import ReactCodeAgent
agent = ReactCodeAgent(tools=[])
agent.run("What is the result of 2 power 3.7384?")
```
"""
Expand Down Expand Up @@ -820,7 +835,7 @@ def __init__(
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: List[str] = [],
additional_authorized_imports: Optional[List[str]] = None,
JasonZhu1313 marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
):
super().__init__(
Expand All @@ -839,7 +854,9 @@ def __init__(
)

self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))

def step(self):
"""
Expand Down Expand Up @@ -871,7 +888,11 @@ def step(self):

# Parse
self.logger.debug("===== Extracting action =====")
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
rationale, raw_code_action = llm_output, llm_output

try:
code_action = parse_code_blob(raw_code_action)
Expand All @@ -890,7 +911,7 @@ def step(self):
code_action,
available_tools,
state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
authorized_imports=self.authorized_imports,
)
information = self.state["print_outputs"]
self.logger.warning("Print outputs:")
Expand Down
34 changes: 21 additions & 13 deletions src/transformers/agents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ def setup_default_tools(logger):

for task_name, tool_class_name in TASK_MAPPING.items():
tool_class = getattr(tools_module, tool_class_name)
tool_instance = tool_class()
default_tools[tool_class.name] = PreTool(
name=tool_class.name,
inputs=tool_class.inputs,
output_type=tool_class.output_type,
name=tool_instance.name,
inputs=tool_instance.inputs,
output_type=tool_instance.output_type,
task=task_name,
description=tool_class.description,
description=tool_instance.description,
repo_id=None,
)

Expand All @@ -141,18 +142,25 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."

inputs = {
"code": {
"type": "text",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}."
),
}
}
output_type = "text"
available_tools = BASE_PYTHON_TOOLS.copy()

def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None:
authorized_imports = list(set(LIST_SAFE_MODULES))
else:
authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = {
"code": {
"type": "text",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
),
}
}
super().__init__(*args, **kwargs)

def forward(self, code):
output = str(evaluate_python_code(code, tools=self.available_tools))
return output
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/agents/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
Be sure to provide a 'Code:' token, else the system will be stuck in a loop.

Tools:
Expand Down Expand Up @@ -262,7 +263,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):


DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can.
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.

At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
Expand Down Expand Up @@ -355,6 +356,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>

Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
"""
12 changes: 9 additions & 3 deletions tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,21 @@ def test_react_fails_max_iterations(self):
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
assert (
len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default

toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool
assert (
len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer

toolset_3 = Toolbox(toolset_2)
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool
assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one

# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
Expand Down
12 changes: 11 additions & 1 deletion tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,19 @@ def add_two(x):

class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("python_interpreter")
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
self.tool.setup()

def test_exact_match_input_spec(self):
inputs_spec = self.tool.inputs
expected_description = (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
"else you will get an error. This code can only import the following python libraries: "
"['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', "
"'random', 're']."
)
self.assertEqual(inputs_spec["code"]["description"], expected_description)

def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "4.0")
Expand Down
Loading