
Commit 87a8080
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 13, 2024
1 parent d7469f1 commit 87a8080
Showing 6 changed files with 47 additions and 48 deletions.
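
The edits below are the kind applied automatically by common pre-commit hooks: single quotes normalized to double quotes and blank lines adjusted around top-level definitions (black-style formatting), imports sorted and merged (isort-style), and missing end-of-file newlines added. As a minimal sketch only, assuming black, isort, and the standard pre-commit-hooks are in use (the repository's actual .pre-commit-config.yaml is not shown on this page, and the pinned versions below are illustrative), a configuration that would produce fixes like these could look like:

    # Hypothetical .pre-commit-config.yaml (a sketch; hook versions are placeholders)
    repos:
      - repo: https://github.com/pre-commit/pre-commit-hooks
        rev: v4.6.0
        hooks:
          - id: end-of-file-fixer    # adds the missing trailing newlines seen below
          - id: trailing-whitespace  # strips stray trailing spaces
      - repo: https://github.com/psf/black
        rev: 24.4.2
        hooks:
          - id: black                # normalizes quotes, spacing, and blank lines
      - repo: https://github.com/PyCQA/isort
        rev: 5.13.2
        hooks:
          - id: isort                # sorts and merges import statements

With such a setup, running pre-commit run --all-files locally applies the same fixes before a push; pre-commit.ci applies them server-side and commits the result, as in this bot commit.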
comps/agent/langchain/agent.py (10 changes: 5 additions & 5 deletions)
@@ -12,7 +12,7 @@
 comps_path = os.path.join(cur_path, "../../../")
 sys.path.append(comps_path)
 
-from comps import LLMParamsDoc, GeneratedDoc, ServiceType, opea_microservices, register_microservice
+from comps import GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
 from comps.agent.langchain.src.agent import instantiate_agent
 from comps.agent.langchain.src.utils import get_args
 
@@ -35,19 +35,19 @@ async def llm_generate(input: LLMParamsDoc):
     agent_inst = instantiate_agent(args, args.strategy)
     print(type(agent_inst))
 
-
     # 2. prepare the input for the agent
     if input.streaming:
-        print('-----------STREAMING-------------')
+        print("-----------STREAMING-------------")
         return StreamingResponse(agent_inst.stream_generator(input.query, config), media_type="text/event-stream")
 
     else:
        # TODO: add support for non-streaming mode
-        print('-----------NOT STREAMING-------------')
+        print("-----------NOT STREAMING-------------")
        response = await agent_inst.non_streaming_run(input.query, config)
-        print('-----------Response-------------')
+        print("-----------Response-------------")
        print(response)
        return GeneratedDoc(text=response, prompt=input.query)
 
+
 if __name__ == "__main__":
     opea_microservices["opea_service@comps-react-agent"].start()
comps/agent/langchain/src/strategy/base_agent.py (2 changes: 1 addition & 1 deletion)
@@ -19,4 +19,4 @@ def execute(self, state: dict):
         pass
 
     def non_streaming_run(self, query, config):
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError
comps/agent/langchain/src/strategy/docgrader/planner.py (56 changes: 27 additions & 29 deletions)
@@ -10,17 +10,17 @@
 from langchain_core.prompts import PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langchain_openai import ChatOpenAI
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import ToolNode, tools_condition
-from langchain_openai import ChatOpenAI
 
 from ..base_agent import BaseAgent
 from .prompt import DOC_GRADER_PROMPT, RAGv1_PROMPT
 
-instruction="Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
-MAX_RETRY=3
+instruction = "Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
+MAX_RETRY = 3
 
 class AgentStateV1(TypedDict):
     # The add_messages function defines how an update should be processed
@@ -30,6 +30,7 @@ class AgentStateV1(TypedDict):
     doc_score: str
     query_time: str
 
+
 class RagAgent:
     """Invokes the agent model to generate a response based on the current state. Given
     the question, it will decide to retrieve using the retriever tool, or simply end.
@@ -61,6 +62,7 @@ class Retriever:
     def create(cls, tools_descriptions):
         return ToolNode(tools_descriptions)
 
+
 class DocumentGraderV1:
     """Determines whether the retrieved documents are relevant to the question.
@@ -93,9 +95,9 @@ class grade(BaseModel):
     def __call__(self, state) -> Literal["generate", "rewrite"]:
         print("---CALL DocumentGrader---")
         messages = state["messages"]
-        last_message = messages[-1] # the latest retrieved doc
+        last_message = messages[-1]  # the latest retrieved doc
 
-        question = messages[0].content # the original query
+        question = messages[0].content  # the original query
         docs = last_message.content
 
         scored_result = self.chain.invoke({"question": question, "context": docs})
@@ -104,12 +106,13 @@ def __call__(self, state) -> Literal["generate", "rewrite"]:
 
         if score.startswith("yes"):
             print("---DECISION: DOCS RELEVANT---")
-            return {"doc_score":"generate"}
+            return {"doc_score": "generate"}
 
         else:
             print(f"---DECISION: DOCS NOT RELEVANT, score is {score}---")
 
-            return {"messages":[HumanMessage(content=instruction)], "doc_score": "rewrite"}
+            return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"}
 
+
 class TextGeneratorV1:
     """Generate answer.
@@ -131,11 +134,11 @@ def __call__(self, state):
         print("---GENERATE---")
         messages = state["messages"]
         question = messages[0].content
-        query_time = state['query_time']
+        query_time = state["query_time"]
 
         # find the latest retrieved doc
         # which is a ToolMessage
-        for m in state['messages'][::-1]:
+        for m in state["messages"][::-1]:
             if isinstance(m, ToolMessage):
                 last_message = m
                 break
@@ -144,11 +147,11 @@ def __call__(self, state):
         docs = last_message.content
 
         # Run
-        response = self.rag_chain.invoke({"context": docs, "question": question, "time":query_time})
-        print('@@@@ Used this doc for generation:\n', docs)
-        print('@@@@ Generated response: ', response)
+        response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
+        print("@@@@ Used this doc for generation:\n", docs)
+        print("@@@@ Generated response: ", response)
         return {"messages": [response], "output": response}
 
 
 class RAGAgentDocGraderV1(BaseAgent):
     def __init__(self, args):
@@ -186,28 +189,28 @@ def __init__(self, args):
             "doc_grader",
             self.should_retry,
             {
-                False: "generate",
-                True: "agent",
+                False: "generate",
+                True: "agent",
             },
         )
         workflow.add_edge("generate", END)
 
         self.app = workflow.compile()
 
     def should_retry(self, state):
         # first check how many retry attempts have been made
         num_retry = 0
-        for m in state['messages']:
+        for m in state["messages"]:
             if instruction in m.content:
                 num_retry += 1
 
         print("**********Num retry: ", num_retry)
-        if (num_retry <MAX_RETRY) and (state["doc_score"] == "rewrite"):
+
+        if (num_retry < MAX_RETRY) and (state["doc_score"] == "rewrite"):
             return True
         else:
             return False
 
     def prepare_initial_state(self, query):
         return {"messages": [HumanMessage(content=query)]}
 
@@ -225,7 +228,7 @@ async def stream_generator(self, query, config):
             yield "data: [DONE]\n\n"
         except Exception as e:
             yield str(e)
-
+
     async def non_streaming_run(self, query, config):
         initial_state = self.prepare_initial_state(query)
         try:
@@ -236,13 +239,8 @@ async def non_streaming_run(self, query, config):
             else:
                 message.pretty_print()
 
-            last_message = s['messages'][-1]
-            print('******Response: ', last_message.content)
+            last_message = s["messages"][-1]
+            print("******Response: ", last_message.content)
             return last_message.content
         except Exception as e:
             return str(e)
-
-
-
-
-
comps/agent/langchain/src/strategy/docgrader/prompt.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
 
-DOC_GRADER_PROMPT="""\
+DOC_GRADER_PROMPT = """\
 Given the QUERY, determine if a relevant answer can be derived from the DOCUMENT.\n
 QUERY: {question} \n
 DOCUMENT:\n{context}\n\n
comps/agent/langchain/src/strategy/react/planner.py (21 changes: 11 additions & 10 deletions)
@@ -1,16 +1,16 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from langchain.agents import AgentExecutor
+from langchain.agents import AgentExecutor
 from langchain.agents import create_react_agent as create_react_langchain_agent
-from langgraph.prebuilt import create_react_agent
 from langchain_core.messages import HumanMessage
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_openai import ChatOpenAI
+from langgraph.prebuilt import create_react_agent
 
 from ...utils import has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
-from .prompt import hwchase17_react_prompt
-from .prompt import REACT_SYS_MESSAGE
+from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
 
 
 class ReActAgentwithLangchain(BaseAgent):
@@ -48,6 +48,7 @@ async def stream_generator(self, query, config):
             print("---")
         yield "data: [DONE]\n\n"
 
+
 class ReActAgentwithLanggraph(BaseAgent):
     def __init__(self, args):
         super().__init__(args)
@@ -59,11 +60,11 @@ def __init__(self, args):
 
         tools = self.tools_descriptions
 
-        self.app = create_react_agent(self.llm, tools = tools, state_modifier=REACT_SYS_MESSAGE)
+        self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE)
 
     def prepare_initial_state(self, query):
         return {"messages": [HumanMessage(content=query)]}
 
     async def stream_generator(self, query, config):
         initial_state = self.prepare_initial_state(query)
         try:
@@ -78,7 +79,7 @@ async def stream_generator(self, query, config):
             yield "data: [DONE]\n\n"
         except Exception as e:
             yield str(e)
-
+
     async def non_streaming_run(self, query, config):
         initial_state = self.prepare_initial_state(query)
         try:
@@ -89,8 +90,8 @@ async def non_streaming_run(self, query, config):
             else:
                 message.pretty_print()
 
-            last_message = s['messages'][-1]
-            print('******Response: ', last_message.content)
+            last_message = s["messages"][-1]
+            print("******Response: ", last_message.content)
             return last_message.content
         except Exception as e:
-            return str(e)
\ No newline at end of file
+            return str(e)
comps/agent/langchain/src/strategy/react/prompt.py (4 changes: 2 additions & 2 deletions)
@@ -8,11 +8,11 @@
 )
 
 
-REACT_SYS_MESSAGE="""\
+REACT_SYS_MESSAGE = """\
 Decompose the user request into a series of simple tasks when necessary and solve the problem step by step.
 When you cannot get the answer at first, do not give up. Reflect on the info you have from the tools and try to solve the problem in a different way.
 Please follow these guidelines when formulating your answer:
 1. If the question contains a false premise or assumption, answer “invalid question”.
 2. If you are uncertain or do not know the answer, respond with “I don’t know”.
 3. Give concise, factual and relevant answers.
-"""
\ No newline at end of file
+"""
