Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #11737 issue (extra_tools option of create_pandas_dataframe_agent is not working) #13203

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def _get_multi_prompt(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
extra_tools: Sequence[BaseTool] = (),
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
num_dfs = len(dfs)
if suffix is not None:
suffix_to_use = suffix
Expand All @@ -55,12 +56,13 @@ def _get_multi_prompt(
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]

tools = [PythonAstREPLTool(locals=df_locals)] + list(extra_tools)
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
tools,
prefix=prefix,
suffix=suffix_to_use,
input_variables=input_variables,
)

partial_prompt = prompt.partial()
if "dfs_head" in input_variables:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
Expand All @@ -77,7 +79,8 @@ def _get_single_prompt(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
extra_tools: Sequence[BaseTool] = (),
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
if suffix is not None:
suffix_to_use = suffix
include_df_head = True
Expand All @@ -96,10 +99,13 @@ def _get_single_prompt(
if prefix is None:
prefix = PREFIX

tools = [PythonAstREPLTool(locals={"df": df})]
tools = [PythonAstREPLTool(locals={"df": df})] + list(extra_tools)

prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
tools,
prefix=prefix,
suffix=suffix_to_use,
input_variables=input_variables,
)

partial_prompt = prompt.partial()
Expand All @@ -117,7 +123,8 @@ def _get_prompt_and_tools(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
extra_tools: Sequence[BaseTool] = (),
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
try:
import pandas as pd

Expand All @@ -141,6 +148,7 @@ def _get_prompt_and_tools(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
extra_tools=extra_tools,
)
else:
if not isinstance(df, pd.DataFrame):
Expand All @@ -152,6 +160,7 @@ def _get_prompt_and_tools(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
extra_tools=extra_tools,
)


Expand Down Expand Up @@ -287,6 +296,7 @@ def create_pandas_dataframe_agent(
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
agent: BaseSingleActionAgent
base_tools: Sequence[BaseTool]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt, base_tools = _get_prompt_and_tools(
df,
Expand All @@ -295,8 +305,9 @@ def create_pandas_dataframe_agent(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
extra_tools=extra_tools,
)
tools = base_tools + list(extra_tools)
tools = base_tools
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
Expand All @@ -318,7 +329,7 @@ def create_pandas_dataframe_agent(
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
tools = base_tools + list(extra_tools)
tools = list(base_tools) + list(extra_tools)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
Expand Down
Loading