Skip to content

Commit

Permalink
Update RAGAgentLlama and ReActLlama (#843)
Browse files Browse the repository at this point in the history
* use ChatOpenAI to interface with TGI-gaudi

Signed-off-by: minmin-intel <[email protected]>

* update tools for unit test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update readme and test

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test script

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* consolidate chat model setup

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tests and readme

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert react llama output parser

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
minmin-intel and pre-commit-ci[bot] authored Nov 1, 2024
1 parent 9f68bd3 commit c8e3639
Show file tree
Hide file tree
Showing 17 changed files with 427 additions and 303 deletions.
37 changes: 23 additions & 14 deletions comps/agent/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,28 @@ We currently support the following types of agents:
1. ReAct: use `react_langchain` or `react_langgraph` or `react_llama` as strategy. First introduced in this seminal [paper](https://arxiv.org/abs/2210.03629). The ReAct agent engages in "reason-act-observe" cycles to solve problems. Please refer to this [doc](https://python.langchain.com/v0.2/docs/how_to/migrate_agent/) to understand the differences between the langchain and langgraph versions of react agents. See table below to understand the validated LLMs for each react strategy.
2. RAG agent: use `rag_agent` or `rag_agent_llama` strategy. This agent is specifically designed for improving RAG performance. It has the capability to rephrase query, check relevancy of retrieved context, and iterate if context is not relevant. See table below to understand the validated LLMs for each rag agent strategy.
3. Plan and execute: `plan_execute` strategy. This type of agent first makes a step-by-step plan given a user request, and then execute the plan sequentially (or in parallel, to be implemented in future). If the execution results can solve the problem, then the agent will output an answer; otherwise, it will replan and execute again.
For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below.

**Note**:

1. Due to the limitations in support for tool calling by TGI and vllm, we have developed subcategories of agent strategies (`rag_agent_llama` and `react_llama`) specifically designed for open-source LLMs served with TGI and vllm.
2. For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below.

### 1.2 LLM engine

Agents use LLM for reasoning and planning. We support 2 options of LLM engine:
Agents use LLM for reasoning and planning. We support 3 options of LLM engine:

1. Open-source LLMs served with TGI-gaudi. To use open-source llms, follow the instructions in [Section 2](#222-start-microservices) below. Note: we recommend using state-of-the-art LLMs, such as llama3.1-70B-instruct, to get higher success rate.
2. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=<your-openai-key>`
1. Open-source LLMs served with TGI. Follow the instructions in [Section 2.2.1](#221-start-agent-microservices-with-tgi).
2. Open-source LLMs served with vllm. Follow the instructions in [Section 2.2.2](#222-start-agent-microservices-with-vllm).
3. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=<your-openai-key>`

| Agent type | `strategy` arg | Validated LLMs | Notes |
| ---------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ReAct | `react_langchain` | GPT-4o-mini, [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Only allows tools with one input variable |
| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)-on-vllm, | Currently does not work for open-source LLMs served with TGI-Gaudi, [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)-on-vllm is not synced from vllm upstream to gaudi repo yet. |
| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Recommended for open-source LLMs served with TGI-Gaudi |
| RAG agent | `rag_agent` | GPT-4o-mini | Currently does not work for open-source LLMs served with TGI-Gaudi |
| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Recommended for open-source LLMs served with TGI-Gaudi, only allows 1 tool with input variable to be "query" |
| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)-on-vllm, [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)-on-vllm | |
| Agent type | `strategy` arg | Validated LLMs (serving SW) | Notes |
| ---------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
| ReAct | `react_langchain` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Only allows tools with one input variable |
| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), | if using vllm, need to specify `--enable-auto-tool-choice --tool-call-parser ${model_parser}`, refer to vllm docs for more info |
| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs |
| RAG agent | `rag_agent` | GPT-4o-mini | |
| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs, only allows 1 tool with input variable to be "query" |
| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) (vllm-gaudi) | |

### 1.3 Tools

Expand All @@ -36,11 +41,15 @@ The tools are registered with a yaml file. We support the following types of too
1. Endpoint: user to provide url
2. User-defined python functions. This is usually used to wrap endpoints with request post or simple pre/post-processing.
3. Langchain tool modules.
Examples of how to register tools can be found in [Section 4](#-4-provide-your-own-tools) below.

Examples of how to register tools can be found in [Section 4](#-4-provide-your-own-tools) below.

### 1.4 Agent APIs

Currently we have implemented OpenAI chat completion compatible API for agents. We are working to support OpenAI assistants APIs.
1. OpenAI compatible chat completions API
2. OpenAI compatible assistants APIs.

**Note**: not all keywords are supported yet.

## 🚀2. Start Agent Microservice

Expand Down
4 changes: 2 additions & 2 deletions comps/agent/langchain/src/strategy/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from uuid import uuid4

from ..tools import get_tools_descriptions
from ..utils import adapt_custom_prompt, setup_llm
from ..utils import adapt_custom_prompt, setup_chat_model


class BaseAgent:
def __init__(self, args, local_vars=None, **kwargs) -> None:
self.llm_endpoint = setup_llm(args)
self.llm = setup_chat_model(args)
self.tools_descriptions = get_tools_descriptions(args.tools)
self.app = None
self.memory = None
Expand Down
60 changes: 27 additions & 33 deletions comps/agent/langchain/src/strategy/planexec/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pydantic import BaseModel, Field

from ...global_var import threads_global_kv
from ...utils import has_multi_tool_inputs, tool_renderer, wrap_chat
from ...utils import has_multi_tool_inputs, tool_renderer
from ..base_agent import BaseAgent
from .prompt import (
answer_check_prompt,
Expand All @@ -25,8 +25,6 @@
replanner_prompt,
)

# Define protocol


class PlanExecute(TypedDict):
messages: Annotated[Sequence[BaseMessage], add_messages]
Expand Down Expand Up @@ -56,16 +54,14 @@ class PlanStepChecker:
str: A decision for whether we should use this plan or not
"""

def __init__(self, llm_endpoint, model_id=None, is_vllm=False):
def __init__(self, llm, is_vllm=False):
class grade(BaseModel):
binary_score: str = Field(description="executable score 'yes' or 'no'")

if is_vllm:
llm = wrap_chat(llm_endpoint, model_id).bind_tools(
[grade], tool_choice={"function": {"name": grade.__name__}}
)
llm = llm.bind_tools([grade], tool_choice={"function": {"name": grade.__name__}})
else:
llm = wrap_chat(llm_endpoint, model_id).bind_tools([grade])
llm = llm.bind_tools([grade])

output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True)
self.chain = plan_check_prompt | llm | output_parser
Expand All @@ -74,7 +70,7 @@ def __call__(self, state):
# print("---CALL PlanStepChecker---")
scored_result = self.chain.invoke(state)
score = scored_result.binary_score
# print(f"Task is {state['context']}, Greade of relevance to question is {score}")
print(f"Task is {state['context']}, Score is {score}")
if score.startswith("yes"):
return True
else:
Expand All @@ -83,13 +79,11 @@ def __call__(self, state):

# Define workflow Node
class Planner:
def __init__(self, llm_endpoint, model_id=None, plan_checker=None, is_vllm=False):
def __init__(self, llm, plan_checker=None, is_vllm=False):
if is_vllm:
llm = wrap_chat(llm_endpoint, model_id).bind_tools(
[Plan], tool_choice={"function": {"name": Plan.__name__}}
)
llm = llm.bind_tools([Plan], tool_choice={"function": {"name": Plan.__name__}})
else:
llm = wrap_chat(llm_endpoint, model_id).bind_tools([Plan])
llm = llm.bind_tools([Plan])
output_parser = PydanticToolsParser(tools=[Plan], first_tool_only=True)
self.llm = planner_prompt | llm | output_parser
self.plan_checker = plan_checker
Expand All @@ -103,6 +97,7 @@ def __call__(self, state):
while not success:
try:
plan = self.llm.invoke({"messages": [("user", state["messages"][-1].content)]})
print("Generated plan: ", plan)
success = True
except OutputParserException as e:
pass
Expand All @@ -116,17 +111,17 @@ def __call__(self, state):

if len(steps) == 0:
success = False

print("Steps: ", steps)
return {"input": input, "plan": steps}


class Executor:
def __init__(self, llm_endpoint, model_id=None, tools=[]):
def __init__(self, llm, tools=[]):
prompt = hwchase17_react_prompt
if has_multi_tool_inputs(tools):
raise ValueError("Only supports single input tools when using strategy == react")
else:
agent_chain = create_react_agent(llm_endpoint, tools, prompt, tools_renderer=tool_renderer)
agent_chain = create_react_agent(llm, tools, prompt, tools_renderer=tool_renderer)
self.agent_executor = AgentExecutor(
agent=agent_chain, tools=tools, handle_parsing_errors=True, max_iterations=50
)
Expand All @@ -148,20 +143,19 @@ def __call__(self, state):
agent_response = self.agent_executor.invoke({"input": task_formatted})
output = agent_response["output"]
success = True
print(f"Task is {step}, Response is {output}")
out_state.append(f"Task is {step}, Response is {output}")
return {
"past_steps": out_state,
}


class AnswerMaker:
def __init__(self, llm_endpoint, model_id=None, is_vllm=False):
def __init__(self, llm, is_vllm=False):
if is_vllm:
llm = wrap_chat(llm_endpoint, model_id).bind_tools(
[Response], tool_choice={"function": {"name": Response.__name__}}
)
llm = llm.bind_tools([Response], tool_choice={"function": {"name": Response.__name__}})
else:
llm = wrap_chat(llm_endpoint, model_id).bind_tools([Response])
llm = llm.bind_tools([Response])
output_parser = PydanticToolsParser(tools=[Response], first_tool_only=True)
self.llm = answer_make_prompt | llm | output_parser

Expand All @@ -172,6 +166,7 @@ def __call__(self, state):
while not success:
try:
output = self.llm.invoke(state)
print("Generated response: ", output.response)
success = True
except OutputParserException as e:
pass
Expand All @@ -188,16 +183,14 @@ class FinalAnswerChecker:
str: A decision for whether we should use this plan or not
"""

def __init__(self, llm_endpoint, model_id=None, is_vllm=False):
def __init__(self, llm, is_vllm=False):
class grade(BaseModel):
binary_score: str = Field(description="executable score 'yes' or 'no'")

if is_vllm:
llm = wrap_chat(llm_endpoint, model_id).bind_tools(
[grade], tool_choice={"function": {"name": grade.__name__}}
)
llm = llm.bind_tools([grade], tool_choice={"function": {"name": grade.__name__}})
else:
llm = wrap_chat(llm_endpoint, model_id).bind_tools([grade])
llm = llm.bind_tools([grade])
output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True)
self.chain = answer_check_prompt | llm | output_parser

Expand All @@ -213,8 +206,8 @@ def __call__(self, state):


class Replanner:
def __init__(self, llm_endpoint, model_id=None, answer_checker=None):
llm = wrap_chat(llm_endpoint, model_id).bind_tools([Plan])
def __init__(self, llm, answer_checker=None):
llm = llm.bind_tools([Plan])
output_parser = PydanticToolsParser(tools=[Plan], first_tool_only=True)
self.llm = replanner_prompt | llm | output_parser
self.answer_checker = answer_checker
Expand All @@ -227,6 +220,7 @@ def __call__(self, state):
try:
output = self.llm.invoke(state)
success = True
print("Replan: ", output)
except OutputParserException as e:
pass
except Exception as e:
Expand All @@ -240,11 +234,11 @@ def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)

# Define Node
plan_checker = PlanStepChecker(self.llm_endpoint, args.model, is_vllm=self.is_vllm)
plan_checker = PlanStepChecker(self.llm, is_vllm=self.is_vllm)

plan_step = Planner(self.llm_endpoint, args.model, plan_checker, is_vllm=self.is_vllm)
execute_step = Executor(self.llm_endpoint, args.model, self.tools_descriptions)
make_answer = AnswerMaker(self.llm_endpoint, args.model, is_vllm=self.is_vllm)
plan_step = Planner(self.llm, plan_checker, is_vllm=self.is_vllm)
execute_step = Executor(self.llm, self.tools_descriptions)
make_answer = AnswerMaker(self.llm, is_vllm=self.is_vllm)

# Define Graph
workflow = StateGraph(PlanExecute)
Expand Down
Loading

0 comments on commit c8e3639

Please sign in to comment.