Skip to content

Commit

Permalink
Add RAG agent and ReAct agent implemention for llama3.1 served by TGI…
Browse files Browse the repository at this point in the history
…-gaudi (#722)

* add ragagent and react agent for llama3.1

Signed-off-by: minmin-intel <[email protected]>

* update ut

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test

Signed-off-by: minmin-intel <[email protected]>

* update test

Signed-off-by: minmin-intel <[email protected]>

* debug ut

Signed-off-by: minmin-intel <[email protected]>

* update test

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test and readme

Signed-off-by: minmin-intel <[email protected]>

* update ragagent llama docgrader

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
minmin-intel and pre-commit-ci[bot] authored Sep 26, 2024
1 parent cc80c1b commit e7fdf53
Show file tree
Hide file tree
Showing 15 changed files with 669 additions and 38 deletions.
13 changes: 11 additions & 2 deletions comps/agent/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ This agent microservice is built on Langchain/Langgraph frameworks. Agents integ

We currently support the following types of agents:

1. ReAct: use `react_langchain` or `react_langgraph` as strategy. First introduced in this seminal [paper](https://arxiv.org/abs/2210.03629). The ReAct agent engages in "reason-act-observe" cycles to solve problems. Please refer to this [doc](https://python.langchain.com/v0.2/docs/how_to/migrate_agent/) to understand the differences between the langchain and langgraph versions of react agents.
2. RAG agent: `rag_agent` strategy. This agent is specifically designed for improving RAG performance. It has the capability to rephrase query, check relevancy of retrieved context, and iterate if context is not relevant.
1. ReAct: use `react_langchain` or `react_langgraph` or `react_llama` as strategy. First introduced in this seminal [paper](https://arxiv.org/abs/2210.03629). The ReAct agent engages in "reason-act-observe" cycles to solve problems. Please refer to this [doc](https://python.langchain.com/v0.2/docs/how_to/migrate_agent/) to understand the differences between the langchain and langgraph versions of react agents. See table below to understand the validated LLMs for each react strategy.
2. RAG agent: use `rag_agent` or `rag_agent_llama` strategy. This agent is specifically designed for improving RAG performance. It has the capability to rephrase query, check relevancy of retrieved context, and iterate if context is not relevant. See table below to understand the validated LLMs for each rag agent strategy.
3. Plan and execute: `plan_execute` strategy. This type of agent first makes a step-by-step plan given a user request, and then execute the plan sequentially (or in parallel, to be implemented in future). If the execution results can solve the problem, then the agent will output an answer; otherwise, it will replan and execute again.
For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below.

Expand All @@ -20,6 +20,15 @@ Agents use LLM for reasoning and planning. We support 2 options of LLM engine:
1. Open-source LLMs served with TGI-gaudi. To use open-source llms, follow the instructions in [Section 2](#222-start-microservices) below. Note: we recommend using state-of-the-art LLMs, such as llama3.1-70B-instruct, to get higher success rate.
2. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=<your-openai-key>`

| Agent type | `strategy` arg | Validated LLMs | Notes |
| ---------------- | ----------------- | ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ |
| ReAct | `react_langchain` | GPT-4o-mini, [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Only allows tools with one input variable |
| ReAct | `react_langgraph` | GPT-4o-mini | Currently does not work for open-source LLMs served with TGI-Gaudi |
| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Recommended for open-source LLMs served with TGI-Gaudi |
| RAG agent | `rag_agent` | GPT-4o-mini | Currently does not work for open-source LLMs served with TGI-Gaudi |
| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Recommended for open-source LLMs served with TGI-Gaudi, only allows 1 tool with input variable to be "query" |
| Plan and execute | `plan_execute` | GPT-4o-mini | |

### 1.3 Tools

The tools are registered with a yaml file. We support the following types of tools:
Expand Down
1 change: 1 addition & 0 deletions comps/agent/langchain/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
logger.info(f"args: {args}")
input.streaming = args.streaming
config = {"recursion_limit": args.recursion_limit}
print("========initiating agent============")
agent_inst = instantiate_agent(args, args.strategy)
if logflag:
logger.info(type(agent_inst))
Expand Down
8 changes: 7 additions & 1 deletion comps/agent/langchain/src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
from .strategy.react import ReActAgentwithLanggraph

return ReActAgentwithLanggraph(args, with_memory)
elif strategy == "react_llama":
print("Initializing ReAct Agent with LLAMA")
from .strategy.react import ReActAgentLlama

return ReActAgentLlama(args, with_memory)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph

return PlanExecuteAgentWithLangGraph(args, with_memory)

elif strategy == "rag_agent":
elif strategy == "rag_agent" or strategy == "rag_agent_llama":
print("Initializing RAG Agent")
from .strategy.ragagent import RAGAgent

return RAGAgent(args, with_memory)
Expand Down
146 changes: 136 additions & 10 deletions comps/agent/langchain/src/strategy/ragagent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Annotated, Any, Literal, Sequence, TypedDict

from langchain.output_parsers import PydanticOutputParser
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
Expand All @@ -17,7 +17,7 @@
from langgraph.prebuilt import ToolNode, tools_condition

from ..base_agent import BaseAgent
from .prompt import DOC_GRADER_PROMPT, RAG_PROMPT
from .prompt import DOC_GRADER_PROMPT, RAG_PROMPT, QueryWriterLlamaPrompt

instruction = "Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
MAX_RETRY = 3
Expand Down Expand Up @@ -61,6 +61,8 @@ def __call__(self, state):
class Retriever:
@classmethod
def create(cls, tools_descriptions):
for tool in tools_descriptions:
print(tool.name)
return ToolNode(tools_descriptions)


Expand Down Expand Up @@ -132,20 +134,23 @@ def __init__(self, llm_endpoint, model_id=None):
self.rag_chain = prompt | llm_endpoint | StrOutputParser()

def __call__(self, state):
from .utils import aggregate_docs

print("---GENERATE---")
messages = state["messages"]
question = messages[0].content
query_time = state["query_time"]

# find the latest retrieved doc
# which is a ToolMessage
for m in state["messages"][::-1]:
if isinstance(m, ToolMessage):
last_message = m
break
# for m in state["messages"][::-1]:
# if isinstance(m, ToolMessage):
# last_message = m
# break
# docs = last_message.content

question = messages[0].content
docs = last_message.content
docs = aggregate_docs(messages)

# Run
response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
Expand All @@ -159,8 +164,13 @@ def __init__(self, args, with_memory=False):
super().__init__(args)

# Define Nodes
document_grader = DocumentGrader(self.llm_endpoint, args.model)
query_writer = QueryWriter(self.llm_endpoint, args.model, self.tools_descriptions)

if args.strategy == "rag_agent":
query_writer = QueryWriter(self.llm_endpoint, args.model, self.tools_descriptions)
document_grader = DocumentGrader(self.llm_endpoint, args.model)
elif args.strategy == "rag_agent_llama":
query_writer = QueryWriterLlama(self.llm_endpoint, args.model, self.tools_descriptions)
document_grader = DocumentGraderLlama(self.llm_endpoint, args.model)
text_generator = TextGenerator(self.llm_endpoint)
retriever = Retriever.create(self.tools_descriptions)

Expand Down Expand Up @@ -248,3 +258,119 @@ async def non_streaming_run(self, query, config):
return last_message.content
except Exception as e:
return str(e)


class QueryWriterLlama:
"""Temporary workaround to use LLM with TGI-Gaudi.
Use custom output parser to parse text string from LLM into tool calls.
Only support one tool. Does NOT support multiple tools.
The tool input variable must be "query".
Only validated with llama3.1-70B-instruct.
Output of the chain is AIMessage.
Streaming=false is required for this chain.
"""

def __init__(self, llm_endpoint, model_id, tools):
from .utils import QueryWriterLlamaOutputParser

assert len(tools) == 1, "Only support one tool, passed in {} tools".format(len(tools))
output_parser = QueryWriterLlamaOutputParser()
prompt = PromptTemplate(
template=QueryWriterLlamaPrompt,
input_variables=["question", "history", "feedback"],
)
llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id)
self.tools = tools
self.chain = prompt | llm | output_parser

def __call__(self, state):
from .utils import assemble_history, convert_json_to_tool_call

print("---CALL QueryWriter---")
messages = state["messages"]

question = messages[0].content
history = assemble_history(messages)
feedback = instruction

response = self.chain.invoke({"question": question, "history": history, "feedback": feedback})
print("Response from query writer llm: ", response)

### Code below assumes one tool call in the response ##############
# if "query" in response:
# add_kw_tc, tool_call = convert_json_to_tool_call(response, self.tools[0])
# # print("Tool call:\n", tool_call)
# response = AIMessage(content="", additional_kwargs=add_kw_tc, tool_calls=[tool_call])
# # print(response)
# else:
# response = AIMessage(content=response["answer"])
# We return a list, because this will get added to the existing list
# return {"messages": [response], "output": response}
######################################################################

############ allow multiple tool calls in one AI message ############
tool_calls = []
for res in response:
if "query" in res:
add_kw_tc, tool_call = convert_json_to_tool_call(res, self.tools[0])
# print("Tool call:\n", tool_call)
tool_calls.append(tool_call)

if tool_calls:
ai_message = AIMessage(content="", additional_kwargs=add_kw_tc, tool_calls=tool_calls)
else:
ai_message = AIMessage(content=response[0]["answer"])

return {"messages": [ai_message], "output": ai_message.content}


class DocumentGraderLlama:
"""Determines whether the retrieved documents are relevant to the question.
Args:
state (messages): The current state
Returns:
str: A decision for whether the documents are relevant or not
"""

def __init__(self, llm_endpoint, model_id=None):
from .prompt import DOC_GRADER_Llama_PROMPT

# Prompt
prompt = PromptTemplate(
template=DOC_GRADER_Llama_PROMPT,
input_variables=["context", "question"],
)

if isinstance(llm_endpoint, HuggingFaceEndpoint):
llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id)
elif isinstance(llm_endpoint, ChatOpenAI):
llm = llm_endpoint
self.chain = prompt | llm

def __call__(self, state) -> Literal["generate", "rewrite"]:
from .utils import aggregate_docs

print("---CALL DocumentGrader---")
messages = state["messages"]

question = messages[0].content # the original query
docs = aggregate_docs(messages)
print("@@@@ Docs: ", docs)

scored_result = self.chain.invoke({"question": question, "context": docs})

score = scored_result.content
print("@@@@ Score: ", score)

# if score.startswith("yes"):
if "yes" in score:
print("---DECISION: DOCS RELEVANT---")
return {"doc_score": "generate"}

else:
print("---DECISION: DOCS NOT RELEVANT---")

return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"}
32 changes: 32 additions & 0 deletions comps/agent/langchain/src/strategy/ragagent/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,35 @@
),
]
)


QueryWriterLlamaPrompt = """\
Given the user question, think step by step.
If you can answer the question without searching the knowledge base, provide your answer.
If you need to search for information in the knowledge base, provide the search query.
Decompose a complex question into a set of simple tasks, and issue search queries for each task.
Here is the history of search queries that you have issued.
{history}
Here are the feedback for the documents retrieved with your search queries.
{feedback}
What is the new query that you should issue to the knowledge base to answer the user question?
Output the new query in JSON format as below.
{{"query": "your new query here"}}
If you plan to issue multiple queries, you must output JSON in multiple lines like the example below.
{{"query": "your first query here"}}
{{"query": "your second query here"}}
If you can directly answer the user question, output your answer in JSON format as below.
{{"answer": "your answer here"}}
User Question: {question}
You Output:\n
"""

DOC_GRADER_Llama_PROMPT = """\
Given the QUERY, determine if the DOCUMENT contains all the information to answer the query.\n
QUERY: {question} \n
DOCUMENT:\n{context}\n\n
Give score 'yes' if the document provides all the information needed to answer the question. Otherwise, give score 'no'. ONLY answer with 'yes' or 'no'. NOTHING ELSE."""
78 changes: 78 additions & 0 deletions comps/agent/langchain/src/strategy/ragagent/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import uuid

from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import BaseOutputParser


class QueryWriterLlamaOutputParser(BaseOutputParser):
def parse(self, text: str):
print("raw output from llm: ", text)
json_lines = text.split("\n")
print("json_lines: ", json_lines)
output = []
for line in json_lines:
try:
output.append(json.loads(line))
except Exception as e:
print("Exception happened in output parsing: ", str(e))
if output:
return output
else:
return None


def convert_json_to_tool_call(json_str, tool):
tool_name = tool.name
tcid = str(uuid.uuid4())
add_kw_tc = {
"tool_calls": [
ChatCompletionOutputToolCall(
function=ChatCompletionOutputFunctionDefinition(
arguments={"query": json_str["query"]}, name=tool_name, description=None
),
id=tcid,
type="function",
)
]
}
tool_call = ToolCall(name=tool_name, args={"query": json_str["query"]}, id=tcid)
return add_kw_tc, tool_call


def assemble_history(messages):
"""
messages: AI (query writer), TOOL (retriever), HUMAN (Doc Grader), AI, TOOL, HUMAN, etc.
"""
query_history = ""
n = 1
for m in messages[1:]: # exclude the first message
if isinstance(m, AIMessage):
# if there is tool call
if hasattr(m, "tool_calls") and len(m.tool_calls) > 0:
for tool_call in m.tool_calls:
query = tool_call["args"]["query"]
query_history += f"{n}. {query}\n"
n += 1
return query_history


def aggregate_docs(messages):
"""
messages: AI (query writer), TOOL (retriever), HUMAN (Doc Grader
"""
docs = []
context = ""
for m in messages[::-1]:
if isinstance(m, ToolMessage):
docs.append(m.content)
elif isinstance(m, AIMessage):
break
for doc in docs[::-1]:
context = context + doc + "\n"
return context
1 change: 1 addition & 0 deletions comps/agent/langchain/src/strategy/react/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@

from .planner import ReActAgentwithLangchain
from .planner import ReActAgentwithLanggraph
from .planner import ReActAgentLlama
Loading

0 comments on commit e7fdf53

Please sign in to comment.