diff --git a/comps/agent/langchain/agent.py b/comps/agent/langchain/agent.py
index f8b3e8377a..fffdc8765d 100644
--- a/comps/agent/langchain/agent.py
+++ b/comps/agent/langchain/agent.py
@@ -12,7 +12,7 @@
 comps_path = os.path.join(cur_path, "../../../")
 sys.path.append(comps_path)
 
-from comps import LLMParamsDoc, GeneratedDoc, ServiceType, opea_microservices, register_microservice
+from comps import GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
 from comps.agent.langchain.src.agent import instantiate_agent
 from comps.agent.langchain.src.utils import get_args
 
@@ -35,19 +35,19 @@ async def llm_generate(input: LLMParamsDoc):
     agent_inst = instantiate_agent(args, args.strategy)
     print(type(agent_inst))
 
-    # 2. prepare the input for the agent
     if input.streaming:
-        print('-----------STREAMING-------------')
+        print("-----------STREAMING-------------")
         return StreamingResponse(agent_inst.stream_generator(input.query, config), media_type="text/event-stream")
     else:
         # TODO: add support for non-streaming mode
-        print('-----------NOT STREAMING-------------')
+        print("-----------NOT STREAMING-------------")
         response = await agent_inst.non_streaming_run(input.query, config)
-        print('-----------Response-------------')
+        print("-----------Response-------------")
         print(response)
         return GeneratedDoc(text=response, prompt=input.query)
 
+
 if __name__ == "__main__":
     opea_microservices["opea_service@comps-react-agent"].start()
diff --git a/comps/agent/langchain/src/strategy/base_agent.py b/comps/agent/langchain/src/strategy/base_agent.py
index fef26c9aea..f9e8fed9e2 100644
--- a/comps/agent/langchain/src/strategy/base_agent.py
+++ b/comps/agent/langchain/src/strategy/base_agent.py
@@ -19,4 +19,4 @@ def execute(self, state: dict):
         pass
 
     def non_streaming_run(self, query, config):
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError
diff --git a/comps/agent/langchain/src/strategy/docgrader/planner.py b/comps/agent/langchain/src/strategy/docgrader/planner.py
index cbab0e1d68..fe83d53f99 100644
--- a/comps/agent/langchain/src/strategy/docgrader/planner.py
+++ b/comps/agent/langchain/src/strategy/docgrader/planner.py
@@ -10,17 +10,17 @@
 from langchain_core.prompts import PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from langchain_openai import ChatOpenAI
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import ToolNode, tools_condition
-from langchain_openai import ChatOpenAI
 
 from ..base_agent import BaseAgent
 from .prompt import DOC_GRADER_PROMPT, RAGv1_PROMPT
 
+instruction = "Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
+MAX_RETRY = 3
 
-instruction="Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
-MAX_RETRY=3
 
 class AgentStateV1(TypedDict):
     # The add_messages function defines how an update should be processed
@@ -30,6 +30,7 @@ class AgentStateV1(TypedDict):
     doc_score: str
     query_time: str
 
+
 class RagAgent:
     """Invokes the agent model to generate a response based on the current state.
     Given the question, it will decide to retrieve using the retriever tool, or simply end.
@@ -61,6 +62,7 @@ class Retriever:
     def create(cls, tools_descriptions):
         return ToolNode(tools_descriptions)
 
+
 class DocumentGraderV1:
     """Determines whether the retrieved documents are relevant to the question.
 
@@ -93,9 +95,9 @@ class grade(BaseModel):
     def __call__(self, state) -> Literal["generate", "rewrite"]:
         print("---CALL DocumentGrader---")
         messages = state["messages"]
-        last_message = messages[-1] # the latest retrieved doc
+        last_message = messages[-1]  # the latest retrieved doc
 
-        question = messages[0].content # the original query
+        question = messages[0].content  # the original query
         docs = last_message.content
 
         scored_result = self.chain.invoke({"question": question, "context": docs})
@@ -104,12 +106,13 @@ def __call__(self, state) -> Literal["generate", "rewrite"]:
 
         if score.startswith("yes"):
             print("---DECISION: DOCS RELEVANT---")
-            return {"doc_score":"generate"}
+            return {"doc_score": "generate"}
 
         else:
             print(f"---DECISION: DOCS NOT RELEVANT, score is {score}---")
-
-            return {"messages":[HumanMessage(content=instruction)], "doc_score": "rewrite"}
+
+            return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"}
+
 
 class TextGeneratorV1:
     """Generate answer.
@@ -131,11 +134,11 @@ def __call__(self, state):
         print("---GENERATE---")
         messages = state["messages"]
         question = messages[0].content
-        query_time = state['query_time']
+        query_time = state["query_time"]
 
         # find the latest retrieved doc
         # which is a ToolMessage
-        for m in state['messages'][::-1]:
+        for m in state["messages"][::-1]:
             if isinstance(m, ToolMessage):
                 last_message = m
                 break
@@ -144,11 +147,11 @@ def __call__(self, state):
         docs = last_message.content
 
         # Run
-        response = self.rag_chain.invoke({"context": docs, "question": question, "time":query_time})
-        print('@@@@ Used this doc for generation:\n', docs)
-        print('@@@@ Generated response: ', response)
+        response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
+        print("@@@@ Used this doc for generation:\n", docs)
+        print("@@@@ Generated response: ", response)
         return {"messages": [response], "output": response}
-
+
 
 class RAGAgentDocGraderV1(BaseAgent):
     def __init__(self, args):
@@ -186,28 +189,28 @@ def __init__(self, args):
             "doc_grader",
             self.should_retry,
             {
-                False: "generate", 
-                True: "agent", 
+                False: "generate",
+                True: "agent",
             },
         )
         workflow.add_edge("generate", END)
 
         self.app = workflow.compile()
-
 
     def should_retry(self, state):
         # first check how many retry attempts have been made
         num_retry = 0
-        for m in state['messages']:
+        for m in state["messages"]:
             if instruction in m.content:
                 num_retry += 1
 
         print("**********Num retry: ", num_retry)
-
-        if (num_retry
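For context on the control flow these hunks assemble: `RAGAgentDocGraderV1` compiles a LangGraph `StateGraph` in which `should_retry` routes the grader's verdict either back to the agent (`True`) or on to generation (`False`), bounded by `MAX_RETRY`. Below is a minimal, self-contained sketch of that retry loop, not the PR's code: the stub node bodies, the simplified state, the explicit `num_retry` counter (planner.py instead counts occurrences of `instruction` in the message history), and the omission of the retriever `ToolNode` are all assumptions for brevity, and the exact `should_retry` predicate is inferred since the final hunk is truncated above.

```python
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

MAX_RETRY = 3  # same bound as in planner.py


class State(TypedDict):
    messages: Annotated[list, add_messages]
    doc_score: str
    num_retry: int  # assumption: explicit counter instead of counting `instruction` messages


def agent(state: State) -> dict:
    # Stub: the real node calls the LLM to (re)formulate the search query.
    return {}


def doc_grader(state: State) -> dict:
    # Stub: the real node grades retrieved docs; always "rewrite" here to exercise the loop.
    return {"doc_score": "rewrite", "num_retry": state.get("num_retry", 0) + 1}


def generate(state: State) -> dict:
    # Stub: the real node runs the RAG chain over the graded docs.
    return {}


def should_retry(state: State) -> bool:
    # Inferred predicate: retry only while under the bound and the grader said "rewrite".
    return state["doc_score"] == "rewrite" and state["num_retry"] < MAX_RETRY


workflow = StateGraph(State)
workflow.add_node("agent", agent)
workflow.add_node("doc_grader", doc_grader)
workflow.add_node("generate", generate)
workflow.add_edge(START, "agent")
workflow.add_edge("agent", "doc_grader")  # the real graph routes through a retriever ToolNode here
workflow.add_conditional_edges("doc_grader", should_retry, {True: "agent", False: "generate"})
workflow.add_edge("generate", END)
app = workflow.compile()

print(app.invoke({"messages": [], "num_retry": 0}))  # loops agent -> doc_grader 3 times, then generates
```

Keeping the verdict in a dedicated `doc_score` state channel, as the diff does, lets the conditional edge stay a pure function of state instead of re-invoking the grader.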