diff --git a/comps/agent/langchain/README.md b/comps/agent/langchain/README.md index b0d894e19..cc03aad17 100644 --- a/comps/agent/langchain/README.md +++ b/comps/agent/langchain/README.md @@ -11,23 +11,28 @@ We currently support the following types of agents: 1. ReAct: use `react_langchain` or `react_langgraph` or `react_llama` as strategy. First introduced in this seminal [paper](https://arxiv.org/abs/2210.03629). The ReAct agent engages in "reason-act-observe" cycles to solve problems. Please refer to this [doc](https://python.langchain.com/v0.2/docs/how_to/migrate_agent/) to understand the differences between the langchain and langgraph versions of react agents. See table below to understand the validated LLMs for each react strategy. 2. RAG agent: use `rag_agent` or `rag_agent_llama` strategy. This agent is specifically designed for improving RAG performance. It has the capability to rephrase query, check relevancy of retrieved context, and iterate if context is not relevant. See table below to understand the validated LLMs for each rag agent strategy. 3. Plan and execute: `plan_execute` strategy. This type of agent first makes a step-by-step plan given a user request, and then execute the plan sequentially (or in parallel, to be implemented in future). If the execution results can solve the problem, then the agent will output an answer; otherwise, it will replan and execute again. - For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below. + +**Note**: + +1. Due to the limitations in support for tool calling by TGI and vllm, we have developed subcategories of agent strategies (`rag_agent_llama` and `react_llama`) specifically designed for open-source LLMs served with TGI and vllm. +2. For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below. ### 1.2 LLM engine -Agents use LLM for reasoning and planning. We support 2 options of LLM engine: +Agents use LLM for reasoning and planning. We support 3 options of LLM engine: -1. Open-source LLMs served with TGI-gaudi. To use open-source llms, follow the instructions in [Section 2](#222-start-microservices) below. Note: we recommend using state-of-the-art LLMs, such as llama3.1-70B-instruct, to get higher success rate. -2. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=` +1. Open-source LLMs served with TGI. Follow the instructions in [Section 2.2.1](#221-start-agent-microservices-with-tgi). +2. Open-source LLMs served with vllm. Follow the instructions in [Section 2.2.2](#222-start-agent-microservices-with-vllm). +3. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=` -| Agent type | `strategy` arg | Validated LLMs | Notes | -| ---------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| ReAct | `react_langchain` | GPT-4o-mini, [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Only allows tools with one input variable | -| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)-on-vllm, | Currently does not work for open-source LLMs served with TGI-Gaudi, [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)-on-vllm is not synced from vllm upstream to gaudi repo yet. | -| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Recommended for open-source LLMs served with TGI-Gaudi | -| RAG agent | `rag_agent` | GPT-4o-mini | Currently does not work for open-source LLMs served with TGI-Gaudi | -| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) | Recommended for open-source LLMs served with TGI-Gaudi, only allows 1 tool with input variable to be "query" | -| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)-on-vllm, [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)-on-vllm | | +| Agent type | `strategy` arg | Validated LLMs (serving SW) | Notes | +| ---------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | +| ReAct | `react_langchain` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Only allows tools with one input variable | +| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), | if using vllm, need to specify `--enable-auto-tool-choice --tool-call-parser ${model_parser}`, refer to vllm docs for more info | +| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs | +| RAG agent | `rag_agent` | GPT-4o-mini | | +| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs, only allows 1 tool with input variable to be "query" | +| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) (vllm-gaudi) | | ### 1.3 Tools @@ -36,11 +41,15 @@ The tools are registered with a yaml file. We support the following types of too 1. Endpoint: user to provide url 2. User-defined python functions. This is usually used to wrap endpoints with request post or simple pre/post-processing. 3. Langchain tool modules. - Examples of how to register tools can be found in [Section 4](#-4-provide-your-own-tools) below. + +Examples of how to register tools can be found in [Section 4](#-4-provide-your-own-tools) below. ### 1.4 Agent APIs -Currently we have implemented OpenAI chat completion compatible API for agents. We are working to support OpenAI assistants APIs. +1. OpenAI compatible chat completions API +2. OpenAI compatible assistants APIs. + +**Note**: not all keywords are supported yet. ## 🚀2. Start Agent Microservice diff --git a/comps/agent/langchain/src/strategy/base_agent.py b/comps/agent/langchain/src/strategy/base_agent.py index 7cb36ff99..beb4fa9f8 100644 --- a/comps/agent/langchain/src/strategy/base_agent.py +++ b/comps/agent/langchain/src/strategy/base_agent.py @@ -4,12 +4,12 @@ from uuid import uuid4 from ..tools import get_tools_descriptions -from ..utils import adapt_custom_prompt, setup_llm +from ..utils import adapt_custom_prompt, setup_chat_model class BaseAgent: def __init__(self, args, local_vars=None, **kwargs) -> None: - self.llm_endpoint = setup_llm(args) + self.llm = setup_chat_model(args) self.tools_descriptions = get_tools_descriptions(args.tools) self.app = None self.memory = None diff --git a/comps/agent/langchain/src/strategy/planexec/planner.py b/comps/agent/langchain/src/strategy/planexec/planner.py index 2e2726330..ad4beb0cc 100644 --- a/comps/agent/langchain/src/strategy/planexec/planner.py +++ b/comps/agent/langchain/src/strategy/planexec/planner.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field from ...global_var import threads_global_kv -from ...utils import has_multi_tool_inputs, tool_renderer, wrap_chat +from ...utils import has_multi_tool_inputs, tool_renderer from ..base_agent import BaseAgent from .prompt import ( answer_check_prompt, @@ -25,8 +25,6 @@ replanner_prompt, ) -# Define protocol - class PlanExecute(TypedDict): messages: Annotated[Sequence[BaseMessage], add_messages] @@ -56,16 +54,14 @@ class PlanStepChecker: str: A decision for whether we should use this plan or not """ - def __init__(self, llm_endpoint, model_id=None, is_vllm=False): + def __init__(self, llm, is_vllm=False): class grade(BaseModel): binary_score: str = Field(description="executable score 'yes' or 'no'") if is_vllm: - llm = wrap_chat(llm_endpoint, model_id).bind_tools( - [grade], tool_choice={"function": {"name": grade.__name__}} - ) + llm = llm.bind_tools([grade], tool_choice={"function": {"name": grade.__name__}}) else: - llm = wrap_chat(llm_endpoint, model_id).bind_tools([grade]) + llm = llm.bind_tools([grade]) output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True) self.chain = plan_check_prompt | llm | output_parser @@ -74,7 +70,7 @@ def __call__(self, state): # print("---CALL PlanStepChecker---") scored_result = self.chain.invoke(state) score = scored_result.binary_score - # print(f"Task is {state['context']}, Greade of relevance to question is {score}") + print(f"Task is {state['context']}, Score is {score}") if score.startswith("yes"): return True else: @@ -83,13 +79,11 @@ def __call__(self, state): # Define workflow Node class Planner: - def __init__(self, llm_endpoint, model_id=None, plan_checker=None, is_vllm=False): + def __init__(self, llm, plan_checker=None, is_vllm=False): if is_vllm: - llm = wrap_chat(llm_endpoint, model_id).bind_tools( - [Plan], tool_choice={"function": {"name": Plan.__name__}} - ) + llm = llm.bind_tools([Plan], tool_choice={"function": {"name": Plan.__name__}}) else: - llm = wrap_chat(llm_endpoint, model_id).bind_tools([Plan]) + llm = llm.bind_tools([Plan]) output_parser = PydanticToolsParser(tools=[Plan], first_tool_only=True) self.llm = planner_prompt | llm | output_parser self.plan_checker = plan_checker @@ -103,6 +97,7 @@ def __call__(self, state): while not success: try: plan = self.llm.invoke({"messages": [("user", state["messages"][-1].content)]}) + print("Generated plan: ", plan) success = True except OutputParserException as e: pass @@ -116,17 +111,17 @@ def __call__(self, state): if len(steps) == 0: success = False - + print("Steps: ", steps) return {"input": input, "plan": steps} class Executor: - def __init__(self, llm_endpoint, model_id=None, tools=[]): + def __init__(self, llm, tools=[]): prompt = hwchase17_react_prompt if has_multi_tool_inputs(tools): raise ValueError("Only supports single input tools when using strategy == react") else: - agent_chain = create_react_agent(llm_endpoint, tools, prompt, tools_renderer=tool_renderer) + agent_chain = create_react_agent(llm, tools, prompt, tools_renderer=tool_renderer) self.agent_executor = AgentExecutor( agent=agent_chain, tools=tools, handle_parsing_errors=True, max_iterations=50 ) @@ -148,6 +143,7 @@ def __call__(self, state): agent_response = self.agent_executor.invoke({"input": task_formatted}) output = agent_response["output"] success = True + print(f"Task is {step}, Response is {output}") out_state.append(f"Task is {step}, Response is {output}") return { "past_steps": out_state, @@ -155,13 +151,11 @@ def __call__(self, state): class AnswerMaker: - def __init__(self, llm_endpoint, model_id=None, is_vllm=False): + def __init__(self, llm, is_vllm=False): if is_vllm: - llm = wrap_chat(llm_endpoint, model_id).bind_tools( - [Response], tool_choice={"function": {"name": Response.__name__}} - ) + llm = llm.bind_tools([Response], tool_choice={"function": {"name": Response.__name__}}) else: - llm = wrap_chat(llm_endpoint, model_id).bind_tools([Response]) + llm = llm.bind_tools([Response]) output_parser = PydanticToolsParser(tools=[Response], first_tool_only=True) self.llm = answer_make_prompt | llm | output_parser @@ -172,6 +166,7 @@ def __call__(self, state): while not success: try: output = self.llm.invoke(state) + print("Generated response: ", output.response) success = True except OutputParserException as e: pass @@ -188,16 +183,14 @@ class FinalAnswerChecker: str: A decision for whether we should use this plan or not """ - def __init__(self, llm_endpoint, model_id=None, is_vllm=False): + def __init__(self, llm, is_vllm=False): class grade(BaseModel): binary_score: str = Field(description="executable score 'yes' or 'no'") if is_vllm: - llm = wrap_chat(llm_endpoint, model_id).bind_tools( - [grade], tool_choice={"function": {"name": grade.__name__}} - ) + llm = llm.bind_tools([grade], tool_choice={"function": {"name": grade.__name__}}) else: - llm = wrap_chat(llm_endpoint, model_id).bind_tools([grade]) + llm = llm.bind_tools([grade]) output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True) self.chain = answer_check_prompt | llm | output_parser @@ -213,8 +206,8 @@ def __call__(self, state): class Replanner: - def __init__(self, llm_endpoint, model_id=None, answer_checker=None): - llm = wrap_chat(llm_endpoint, model_id).bind_tools([Plan]) + def __init__(self, llm, answer_checker=None): + llm = llm.bind_tools([Plan]) output_parser = PydanticToolsParser(tools=[Plan], first_tool_only=True) self.llm = replanner_prompt | llm | output_parser self.answer_checker = answer_checker @@ -227,6 +220,7 @@ def __call__(self, state): try: output = self.llm.invoke(state) success = True + print("Replan: ", output) except OutputParserException as e: pass except Exception as e: @@ -240,11 +234,11 @@ def __init__(self, args, with_memory=False, **kwargs): super().__init__(args, local_vars=globals(), **kwargs) # Define Node - plan_checker = PlanStepChecker(self.llm_endpoint, args.model, is_vllm=self.is_vllm) + plan_checker = PlanStepChecker(self.llm, is_vllm=self.is_vllm) - plan_step = Planner(self.llm_endpoint, args.model, plan_checker, is_vllm=self.is_vllm) - execute_step = Executor(self.llm_endpoint, args.model, self.tools_descriptions) - make_answer = AnswerMaker(self.llm_endpoint, args.model, is_vllm=self.is_vllm) + plan_step = Planner(self.llm, plan_checker, is_vllm=self.is_vllm) + execute_step = Executor(self.llm, self.tools_descriptions) + make_answer = AnswerMaker(self.llm, is_vllm=self.is_vllm) # Define Graph workflow = StateGraph(PlanExecute) diff --git a/comps/agent/langchain/src/strategy/ragagent/planner.py b/comps/agent/langchain/src/strategy/ragagent/planner.py index 0d8f0410e..1adb06d7c 100644 --- a/comps/agent/langchain/src/strategy/ragagent/planner.py +++ b/comps/agent/langchain/src/strategy/ragagent/planner.py @@ -4,17 +4,13 @@ from typing import Annotated, Literal, Sequence, TypedDict from langchain_core.messages import AIMessage, BaseMessage, HumanMessage -from langchain_core.output_parsers import StrOutputParser -from langchain_core.output_parsers.openai_tools import PydanticToolsParser from langchain_core.prompts import PromptTemplate -from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint -from langchain_openai import ChatOpenAI from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode, tools_condition -from pydantic import BaseModel, Field +from ...utils import setup_chat_model from ..base_agent import BaseAgent from .prompt import DOC_GRADER_PROMPT, RAG_PROMPT, QueryWriterLlamaPrompt @@ -42,11 +38,8 @@ class QueryWriter: dict: The updated state with the response appended to messages """ - def __init__(self, llm_endpoint, model_id, tools): - if isinstance(llm_endpoint, HuggingFaceEndpoint): - self.llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id).bind_tools(tools) - elif isinstance(llm_endpoint, ChatOpenAI): - self.llm = llm_endpoint.bind_tools(tools) + def __init__(self, llm, tools): + self.llm = llm.bind_tools(tools) def __call__(self, state): print("---CALL QueryWriter---") @@ -65,112 +58,20 @@ def create(cls, tools_descriptions): return ToolNode(tools_descriptions) -class DocumentGrader: - """Determines whether the retrieved documents are relevant to the question. - - Args: - state (messages): The current state - - Returns: - str: A decision for whether the documents are relevant or not - """ - - def __init__(self, llm_endpoint, model_id=None): - class grade(BaseModel): - """Binary score for relevance check.""" - - binary_score: str = Field(description="Relevance score 'yes' or 'no'") - - # Prompt - prompt = PromptTemplate( - template=DOC_GRADER_PROMPT, - input_variables=["context", "question"], - ) - - if isinstance(llm_endpoint, HuggingFaceEndpoint): - llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id).bind_tools([grade]) - elif isinstance(llm_endpoint, ChatOpenAI): - llm = llm_endpoint.bind_tools([grade]) - output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True) - self.chain = prompt | llm | output_parser - - def __call__(self, state) -> Literal["generate", "rewrite"]: - print("---CALL DocumentGrader---") - messages = state["messages"] - last_message = messages[-1] # the latest retrieved doc - - question = messages[0].content # the original query - docs = last_message.content - - scored_result = self.chain.invoke({"question": question, "context": docs}) - - score = scored_result.binary_score - - if score.startswith("yes"): - print("---DECISION: DOCS RELEVANT---") - return {"doc_score": "generate"} - - else: - print(f"---DECISION: DOCS NOT RELEVANT, score is {score}---") - - return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"} - - -class TextGenerator: - """Generate answer. - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - def __init__(self, llm_endpoint, model_id=None): - # Chain - # prompt = rlm_rag_prompt - prompt = RAG_PROMPT - self.rag_chain = prompt | llm_endpoint | StrOutputParser() - - def __call__(self, state): - from .utils import aggregate_docs - - print("---GENERATE---") - messages = state["messages"] - question = messages[0].content - query_time = state["query_time"] - - # find the latest retrieved doc - # which is a ToolMessage - # for m in state["messages"][::-1]: - # if isinstance(m, ToolMessage): - # last_message = m - # break - # docs = last_message.content - - question = messages[0].content - docs = aggregate_docs(messages) - - # Run - response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time}) - print("@@@@ Used this doc for generation:\n", docs) - print("@@@@ Generated response: ", response) - return {"messages": [response], "output": response} - - class RAGAgent(BaseAgent): def __init__(self, args, with_memory=False, **kwargs): super().__init__(args, local_vars=globals(), **kwargs) # Define Nodes - if args.strategy == "rag_agent": - query_writer = QueryWriter(self.llm_endpoint, args.model, self.tools_descriptions) - document_grader = DocumentGrader(self.llm_endpoint, args.model) + query_writer = QueryWriter(self.llm, self.tools_descriptions) elif args.strategy == "rag_agent_llama": - query_writer = QueryWriterLlama(self.llm_endpoint, args.model, self.tools_descriptions) - document_grader = DocumentGraderLlama(self.llm_endpoint, args.model) - text_generator = TextGenerator(self.llm_endpoint) + query_writer = QueryWriterLlama(args, self.tools_descriptions) + else: + raise ValueError("Only support 'rag_agent' or 'rag_agent_llama' strategy") + document_grader = DocumentGrader(args) + text_generator = TextGenerator(args) + retriever = Retriever.create(self.tools_descriptions) # Define graph @@ -225,7 +126,7 @@ def should_retry(self, state): return False def prepare_initial_state(self, query): - return {"messages": [HumanMessage(content=query)], "query_time": ""} + return {"messages": [HumanMessage(content=query)], "query_time": "", "output": "", "doc_score": ""} async def stream_generator(self, query, config): initial_state = self.prepare_initial_state(query) @@ -270,16 +171,19 @@ class QueryWriterLlama: Streaming=false is required for this chain. """ - def __init__(self, llm_endpoint, model_id, tools): + def __init__(self, args, tools): from .utils import QueryWriterLlamaOutputParser assert len(tools) == 1, "Only support one tool, passed in {} tools".format(len(tools)) + self.tools = tools + self.args = args + output_parser = QueryWriterLlamaOutputParser() prompt = PromptTemplate( template=QueryWriterLlamaPrompt, input_variables=["question", "history", "feedback"], ) - llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id) + llm = setup_chat_model(args) self.tools = tools self.chain = prompt | llm | output_parser @@ -296,18 +200,6 @@ def __call__(self, state): response = self.chain.invoke({"question": question, "history": history, "feedback": feedback}) print("Response from query writer llm: ", response) - ### Code below assumes one tool call in the response ############## - # if "query" in response: - # add_kw_tc, tool_call = convert_json_to_tool_call(response, self.tools[0]) - # # print("Tool call:\n", tool_call) - # response = AIMessage(content="", additional_kwargs=add_kw_tc, tool_calls=[tool_call]) - # # print(response) - # else: - # response = AIMessage(content=response["answer"]) - # We return a list, because this will get added to the existing list - # return {"messages": [response], "output": response} - ###################################################################### - ############ allow multiple tool calls in one AI message ############ tool_calls = [] for res in response: @@ -324,7 +216,7 @@ def __call__(self, state): return {"messages": [ai_message], "output": ai_message.content} -class DocumentGraderLlama: +class DocumentGrader: """Determines whether the retrieved documents are relevant to the question. Args: @@ -334,19 +226,12 @@ class DocumentGraderLlama: str: A decision for whether the documents are relevant or not """ - def __init__(self, llm_endpoint, model_id=None): - from .prompt import DOC_GRADER_Llama_PROMPT - - # Prompt + def __init__(self, args): prompt = PromptTemplate( - template=DOC_GRADER_Llama_PROMPT, + template=DOC_GRADER_PROMPT, input_variables=["context", "question"], ) - - if isinstance(llm_endpoint, HuggingFaceEndpoint): - llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id) - elif isinstance(llm_endpoint, ChatOpenAI): - llm = llm_endpoint + llm = setup_chat_model(args) self.chain = prompt | llm def __call__(self, state) -> Literal["generate", "rewrite"]: @@ -371,5 +256,37 @@ def __call__(self, state) -> Literal["generate", "rewrite"]: else: print("---DECISION: DOCS NOT RELEVANT---") - return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"} + + +class TextGenerator: + """Generate answer. + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + + def __init__(self, args): + self.args = args + prompt = RAG_PROMPT + llm = setup_chat_model(args) + self.rag_chain = prompt | llm + + def __call__(self, state): + from .utils import aggregate_docs + + print("---GENERATE---") + messages = state["messages"] + question = messages[0].content + query_time = state["query_time"] + + question = messages[0].content + docs = aggregate_docs(messages) + + response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time}) + print("@@@@ Used this doc for generation:\n", docs) + print("@@@@ Generated response: ", response) + return {"messages": [response], "output": response} diff --git a/comps/agent/langchain/src/strategy/ragagent/prompt.py b/comps/agent/langchain/src/strategy/ragagent/prompt.py index e55e819ba..c2990af75 100644 --- a/comps/agent/langchain/src/strategy/ragagent/prompt.py +++ b/comps/agent/langchain/src/strategy/ragagent/prompt.py @@ -4,10 +4,10 @@ from langchain_core.prompts import ChatPromptTemplate DOC_GRADER_PROMPT = """\ -Given the QUERY, determine if a relevant answer can be derived from the DOCUMENT.\n +Given the QUERY, determine if the DOCUMENT contains all the information to answer the query.\n QUERY: {question} \n DOCUMENT:\n{context}\n\n -Give score 'yes' if the document provides sufficient and relevant information to answer the question. Otherwise, give score 'no'. ONLY answer with 'yes' or 'no'. NOTHING ELSE.""" +Give score 'yes' if the document provides all the information needed to answer the question. Otherwise, give score 'no'. ONLY answer with 'yes' or 'no'. NOTHING ELSE.""" PROMPT = """\ @@ -60,9 +60,3 @@ User Question: {question} You Output:\n """ - -DOC_GRADER_Llama_PROMPT = """\ -Given the QUERY, determine if the DOCUMENT contains all the information to answer the query.\n -QUERY: {question} \n -DOCUMENT:\n{context}\n\n -Give score 'yes' if the document provides all the information needed to answer the question. Otherwise, give score 'no'. ONLY answer with 'yes' or 'no'. NOTHING ELSE.""" diff --git a/comps/agent/langchain/src/strategy/ragagent/utils.py b/comps/agent/langchain/src/strategy/ragagent/utils.py index 1494ca29c..76eb605bc 100644 --- a/comps/agent/langchain/src/strategy/ragagent/utils.py +++ b/comps/agent/langchain/src/strategy/ragagent/utils.py @@ -14,17 +14,21 @@ class QueryWriterLlamaOutputParser(BaseOutputParser): def parse(self, text: str): print("raw output from llm: ", text) json_lines = text.split("\n") - print("json_lines: ", json_lines) output = [] for line in json_lines: try: - output.append(json.loads(line)) + if "assistant" in line: + line = line.replace("assistant", "") + print("line: ", line) + parsed = json.loads(line) + if isinstance(parsed, dict): + output.append(parsed) except Exception as e: print("Exception happened in output parsing: ", str(e)) if output: return output else: - return None + return "Error occurred when parsing LLM output." def convert_json_to_tool_call(json_str, tool): diff --git a/comps/agent/langchain/src/strategy/react/planner.py b/comps/agent/langchain/src/strategy/react/planner.py index ad471ad99..f574b5f65 100644 --- a/comps/agent/langchain/src/strategy/react/planner.py +++ b/comps/agent/langchain/src/strategy/react/planner.py @@ -11,7 +11,7 @@ from langgraph.prebuilt import create_react_agent from ...global_var import threads_global_kv -from ...utils import has_multi_tool_inputs, tool_renderer, wrap_chat +from ...utils import has_multi_tool_inputs, tool_renderer from ..base_agent import BaseAgent from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt @@ -24,7 +24,7 @@ def __init__(self, args, with_memory=False, **kwargs): raise ValueError("Only supports single input tools when using strategy == react_langchain") else: agent_chain = create_react_langchain_agent( - self.llm_endpoint, self.tools_descriptions, prompt, tools_renderer=tool_renderer + self.llm, self.tools_descriptions, prompt, tools_renderer=tool_renderer ) self.app = AgentExecutor( agent=agent_chain, tools=self.tools_descriptions, verbose=True, handle_parsing_errors=True @@ -84,8 +84,6 @@ class ReActAgentwithLanggraph(BaseAgent): def __init__(self, args, with_memory=False, **kwargs): super().__init__(args, local_vars=globals(), **kwargs) - self.llm = wrap_chat(self.llm_endpoint, args.model) - tools = self.tools_descriptions print("REACT_SYS_MESSAGE: ", REACT_SYS_MESSAGE) @@ -131,20 +129,24 @@ async def non_streaming_run(self, query, config): return str(e) +############################################################################### +# ReActAgentLlama: +# Only validated with with Llama3.1-70B-Instruct model served with TGI-gaudi +# support multiple tools +# does not rely on langchain bind_tools API +# since tgi and vllm still do not have very good support for tool calling like OpenAI + from typing import Annotated, Sequence, TypedDict from langchain_core.messages import AIMessage, BaseMessage from langchain_core.prompts import PromptTemplate from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages - -############################################################################### -# ReAct Agent: -# Temporary workaround for open-source LLM served by TGI-gaudi -# Only validated with with Llama3.1-70B-Instruct model served with TGI-gaudi from langgraph.managed import IsLastStep from langgraph.prebuilt import ToolNode +from ...utils import setup_chat_model + class AgentState(TypedDict): """The state of the agent.""" @@ -159,7 +161,7 @@ class ReActAgentNodeLlama: A workaround for open-source llm served by TGI-gaudi. """ - def __init__(self, llm_endpoint, model_id, tools, args): + def __init__(self, tools, args): from .prompt import REACT_AGENT_LLAMA_PROMPT from .utils import ReActLlamaOutputParser @@ -168,7 +170,7 @@ def __init__(self, llm_endpoint, model_id, tools, args): template=REACT_AGENT_LLAMA_PROMPT, input_variables=["input", "history", "tools"], ) - llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id) + llm = setup_chat_model(args) self.tools = tools self.chain = prompt | llm | output_parser @@ -201,7 +203,7 @@ def __call__(self, state): if tool_calls: ai_message = AIMessage(content="", additional_kwargs=add_kw_tc, tool_calls=tool_calls) elif "answer" in output[0]: - ai_message = AIMessage(content=output[0]["answer"]) + ai_message = AIMessage(content=str(output[0]["answer"])) else: ai_message = AIMessage(content=output) return {"messages": [ai_message]} @@ -210,9 +212,7 @@ def __call__(self, state): class ReActAgentLlama(BaseAgent): def __init__(self, args, with_memory=False, **kwargs): super().__init__(args, local_vars=globals(), **kwargs) - agent = ReActAgentNodeLlama( - llm_endpoint=self.llm_endpoint, model_id=args.model, tools=self.tools_descriptions, args=args - ) + agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args) tool_node = ToolNode(self.tools_descriptions) workflow = StateGraph(AgentState) diff --git a/comps/agent/langchain/src/strategy/react/prompt.py b/comps/agent/langchain/src/strategy/react/prompt.py index b81f3bef0..681c1325f 100644 --- a/comps/agent/langchain/src/strategy/react/prompt.py +++ b/comps/agent/langchain/src/strategy/react/prompt.py @@ -17,34 +17,43 @@ 3. Give concise, factual and relevant answers. """ - REACT_AGENT_LLAMA_PROMPT = """\ -Given the user request, think through the problem step by step. -Observe the outputs from the tools in the execution history, and think if you can come up with an answer or not. If yes, provide the answer. If not, make tool calls. -When you cannot get the answer at first, do not give up. Reflect on the steps you have taken so far and try to solve the problem in a different way. - -You have access to the following tools: +You are tasked with answering user questions. +You have the following tools to gather information: {tools} -Begin Execution History: -{history} -End Execution History. - -If you need to call tools, use the following format: -{{"tool":"tool 1", "args":{{"input 1": "input 1 value", "input 2": "input 2 value"}}}} -{{"tool":"tool 2", "args":{{"input 3": "input 3 value", "input 4": "input 4 value"}}}} -Multiple tools can be called in a single step, but always separate each tool call with a newline. +**Procedure:** +1. Read the question carefully. Divide the question into sub-questions and conquer sub-questions one by one. +2. Read the execution history if any to understand the tools that have been called and the information that has been gathered. +3. Reason about the information gathered so far and decide if you can answer the question or if you need to call more tools. -IMPORTANT: You MUST ALWAYS make tool calls unless you can provide an answer. Make each tool call in JSON format in a new line. +**Output format:** +You should output your thought process. Finish thinking first. Output tool calls or your answer at the end. +When making tool calls, you should use the following format: +TOOL CALL: {{"tool": "tool1", "args": {{"arg1": "value1", "arg2": "value2", ...}}}} +TOOL CALL: {{"tool": "tool2", "args": {{"arg1": "value1", "arg2": "value2", ...}}}} -If you can generate an answer, provide the answer in the following format in a new line: -{{"answer": "your answer here"}} +If you can answer the question, provide the answer in the following format: +FINAL ANSWER: {{"answer": "your answer here"}} Follow these guidelines when formulating your answer: 1. If the question contains a false premise or assumption, answer “invalid question”. 2. If you are uncertain or do not know the answer, answer “I don't know”. 3. Give concise, factual and relevant answers. -User request: {input} -Now begin! +**IMPORTANT:** +* Divide the question into sub-questions and conquer sub-questions one by one. +* Questions may be time sensitive. Pay attention to the time when the question was asked. +* You may need to combine information from multiple tools to answer the question. +* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box. You hard work will be rewarded. +* Do not make up tool outputs. + +======= Your task ======= +Question: {input} + +Execution History: +{history} +======================== + +Now take a deep breath and think step by step to solve the problem. """ diff --git a/comps/agent/langchain/src/strategy/react/utils.py b/comps/agent/langchain/src/strategy/react/utils.py index 7b700ac15..f303b424a 100644 --- a/comps/agent/langchain/src/strategy/react/utils.py +++ b/comps/agent/langchain/src/strategy/react/utils.py @@ -14,19 +14,25 @@ class ReActLlamaOutputParser(BaseOutputParser): def parse(self, text: str): print("raw output from llm: ", text) json_lines = text.split("\n") - print("json_lines: ", json_lines) output = [] for line in json_lines: try: + if "TOOL CALL:" in line: + line = line.replace("TOOL CALL:", "") + if "FINAL ANSWER:" in line: + line = line.replace("FINAL ANSWER:", "") if "assistant" in line: line = line.replace("assistant", "") - output.append(json.loads(line)) + parsed_line = json.loads(line) + if isinstance(parsed_line, dict): + print("parsed line: ", parsed_line) + output.append(parsed_line) except Exception as e: print("Exception happened in output parsing: ", str(e)) if output: return output else: - return text # None + return text def convert_json_to_tool_call(json_str): @@ -46,12 +52,21 @@ def convert_json_to_tool_call(json_str): return add_kw_tc, tool_call +def get_tool_output(messages, id): + for msg in reversed(messages): + if isinstance(msg, ToolMessage): + if msg.tool_call_id == id: + tool_output = msg.content + break + return tool_output + + def assemble_history(messages): """ messages: AI, TOOL, AI, TOOL, etc. """ query_history = "" - n = 1 + breaker = "-" * 10 for m in messages[1:]: # exclude the first message if isinstance(m, AIMessage): # if there is tool call @@ -59,10 +74,11 @@ def assemble_history(messages): for tool_call in m.tool_calls: tool = tool_call["name"] tc_args = tool_call["args"] - query_history += f"Tool Call: {tool} - {tc_args}\n" + id = tool_call["id"] + tool_output = get_tool_output(messages, id) + query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n" else: # did not make tool calls - query_history += f"Assistant Output {n}: {m.content}\n" - elif isinstance(m, ToolMessage): - query_history += f"Tool Output: {m.content}\n" + query_history += f"Assistant Output: {m.content}\n" + return query_history diff --git a/comps/agent/langchain/src/utils.py b/comps/agent/langchain/src/utils.py index 3c121678f..fc1cde9ca 100644 --- a/comps/agent/langchain/src/utils.py +++ b/comps/agent/langchain/src/utils.py @@ -7,16 +7,6 @@ from .config import env_config -def wrap_chat(llm_endpoint, model_id): - from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint - - if isinstance(llm_endpoint, HuggingFaceEndpoint): - llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id) - else: - llm = llm_endpoint - return llm - - def format_date(date): # input m/dd/yyyy hr:min # output yyyy-mm-dd @@ -49,54 +39,30 @@ def setup_hf_tgi_client(args): } llm = HuggingFaceEndpoint( - endpoint_url=args.llm_endpoint_url, ## endpoint_url = "localhost:8080", + endpoint_url=args.llm_endpoint_url, task="text-generation", **generation_params, ) return llm -def setup_vllm_client(args): - from langchain_openai import ChatOpenAI - - openai_endpoint = f"{args.llm_endpoint_url}/v1" - params = { - "temperature": args.temperature, - "max_tokens": args.max_new_tokens, - "streaming": args.streaming, - } - llm = ChatOpenAI(openai_api_key="EMPTY", openai_api_base=openai_endpoint, model_name=args.model, **params) - return llm - - -def setup_openai_client(args): - """Lower values for temperature result in more consistent outputs (e.g. 0.2), - while higher values generate more diverse and creative results (e.g. 1.0). - - Select a temperature value based on the desired trade-off between coherence - and creativity for your specific application. The temperature can range is from 0 to 2. - """ +def setup_chat_model(args): from langchain_openai import ChatOpenAI params = { "temperature": args.temperature, "max_tokens": args.max_new_tokens, + "top_p": args.top_p, "streaming": args.streaming, } - llm = ChatOpenAI(model_name=args.model, **params) - return llm - - -def setup_llm(args): - if args.llm_engine == "vllm": - model = setup_vllm_client(args) - elif args.llm_engine == "tgi": - model = setup_hf_tgi_client(args) + if args.llm_engine == "vllm" or args.llm_engine == "tgi": + openai_endpoint = f"{args.llm_endpoint_url}/v1" + llm = ChatOpenAI(openai_api_key="EMPTY", openai_api_base=openai_endpoint, model_name=args.model, **params) elif args.llm_engine == "openai": - model = setup_openai_client(args) + llm = ChatOpenAI(model_name=args.model, **params) else: - raise ValueError("Only supports vllm or hf_tgi mode for now") - return model + raise ValueError("llm_engine must be vllm, tgi or openai") + return llm def tool_renderer(tools): diff --git a/comps/agent/langchain/tools/custom_tools.py b/comps/agent/langchain/tools/custom_tools.py new file mode 100644 index 000000000..d87a99374 --- /dev/null +++ b/comps/agent/langchain/tools/custom_tools.py @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +# tool for unit test +def search_web(query: str) -> str: + """Search the web for a given query.""" + ret_text = """ + The Linux Foundation AI & Data announced the Open Platform for Enterprise AI (OPEA) as its latest Sandbox Project. + OPEA aims to accelerate secure, cost-effective generative AI (GenAI) deployments for businesses by driving interoperability across a diverse and heterogeneous ecosystem, starting with retrieval-augmented generation (RAG). + """ + return ret_text diff --git a/comps/agent/langchain/tools/custom_tools.yaml b/comps/agent/langchain/tools/custom_tools.yaml index 905106ee3..86df07634 100644 --- a/comps/agent/langchain/tools/custom_tools.yaml +++ b/comps/agent/langchain/tools/custom_tools.yaml @@ -1,5 +1,11 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -duckduckgo_search: - callable_api: ddg-search +search_knowledge_base: + description: Search the web for a given query. + callable_api: custom_tools.py:search_web + args_schema: + query: + type: str + description: query + return_output: retrieved_data diff --git a/tests/agent/planexec_openai.yaml b/tests/agent/planexec_openai.yaml new file mode 100644 index 000000000..e1e92dfd2 --- /dev/null +++ b/tests/agent/planexec_openai.yaml @@ -0,0 +1,28 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +services: + react-agent: + image: ${agent_image} + container_name: test-comps-agent-endpoint + volumes: + - ${TOOLSET_PATH}:/home/user/tools/ + ports: + - "9095:9095" + ipc: host + environment: + ip_address: ${ip_address} + strategy: plan_execute + recursion_limit: ${recursion_limit} + llm_engine: openai + OPENAI_API_KEY: ${OPENAI_API_KEY} + model: "gpt-4o-mini" + temperature: ${temperature} + max_new_tokens: ${max_new_tokens} + streaming: false + tools: /home/user/tools/custom_tools.yaml + require_human_feedback: false + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + port: 9095 diff --git a/tests/agent/ragagent_openai.yaml b/tests/agent/ragagent_openai.yaml new file mode 100644 index 000000000..13fb32d8e --- /dev/null +++ b/tests/agent/ragagent_openai.yaml @@ -0,0 +1,28 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +services: + rag-agent: + image: ${agent_image} + container_name: test-comps-agent-endpoint + volumes: + - ${TOOLSET_PATH}:/home/user/tools/ + ports: + - "9095:9095" + ipc: host + environment: + ip_address: ${ip_address} + strategy: rag_agent + recursion_limit: ${recursion_limit} + llm_engine: openai + OPENAI_API_KEY: ${OPENAI_API_KEY} + model: "gpt-4o-mini" + temperature: ${temperature} + max_new_tokens: ${max_new_tokens} + streaming: false + tools: /home/user/tools/custom_tools.yaml + require_human_feedback: false + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + port: 9095 diff --git a/tests/agent/react_langgraph_openai.yaml b/tests/agent/react_langgraph_openai.yaml new file mode 100644 index 000000000..6afe2fb2c --- /dev/null +++ b/tests/agent/react_langgraph_openai.yaml @@ -0,0 +1,28 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +services: + react-agent: + image: ${agent_image} + container_name: test-comps-agent-endpoint + volumes: + - ${TOOLSET_PATH}:/home/user/tools/ + ports: + - "9095:9095" + ipc: host + environment: + ip_address: ${ip_address} + strategy: react_langgraph + recursion_limit: ${recursion_limit} + llm_engine: openai + OPENAI_API_KEY: ${OPENAI_API_KEY} + model: "gpt-4o-mini" + temperature: ${temperature} + max_new_tokens: ${max_new_tokens} + streaming: false + tools: /home/user/tools/custom_tools.yaml + require_human_feedback: false + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + port: 9095 diff --git a/tests/agent/test.py b/tests/agent/test.py new file mode 100644 index 000000000..fdbfe1c5b --- /dev/null +++ b/tests/agent/test.py @@ -0,0 +1,57 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +import requests + + +def generate_answer_agent_api(url, prompt): + proxies = {"http": ""} + payload = { + "query": prompt, + } + response = requests.post(url, json=payload, proxies=proxies) + answer = response.json()["text"] + return answer + + +def process_request(url, query, is_stream=False): + proxies = {"http": ""} + + payload = { + "query": query, + } + + try: + resp = requests.post(url=url, json=payload, proxies=proxies, stream=is_stream) + if not is_stream: + ret = resp.json()["text"] + print(ret) + else: + for line in resp.iter_lines(decode_unicode=True): + print(line) + ret = None + + resp.raise_for_status() # Raise an exception for unsuccessful HTTP status codes + return ret + except requests.exceptions.RequestException as e: + ret = f"An error occurred:{e}" + print(ret) + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--stream", action="store_true", help="Stream the response") + args = parser.parse_args() + + ip_address = os.getenv("ip_address", "localhost") + url = f"http://{ip_address}:9095/v1/chat/completions" + prompt = "What is OPEA?" + if args.stream: + process_request(url, prompt, is_stream=True) + else: + answer = generate_answer_agent_api(url, prompt) + print(answer) diff --git a/tests/agent/test_agent_langchain_on_intel_hpu.sh b/tests/agent/test_agent_langchain_on_intel_hpu.sh index 5f35fb74d..4cc36164e 100644 --- a/tests/agent/test_agent_langchain_on_intel_hpu.sh +++ b/tests/agent/test_agent_langchain_on_intel_hpu.sh @@ -5,6 +5,7 @@ #set -xe WORKPATH=$(dirname "$PWD") +echo $WORKPATH LOG_PATH="$WORKPATH/tests" ip_address=$(hostname -I | awk '{print $1}') tgi_port=8085 @@ -141,7 +142,16 @@ function start_react_langchain_agent_service() { } -function start_react_langgraph_agent_service() { +function start_react_langgraph_agent_service_openai() { + echo "Starting react_langchain agent microservice" + docker compose -f $WORKPATH/tests/agent/react_langgraph_openai.yaml up -d + sleep 5s + docker logs test-comps-agent-endpoint + echo "Service started successfully" +} + + +function start_react_llama_agent_service() { echo "Starting react_langgraph agent microservice" docker compose -f $WORKPATH/tests/agent/reactllama.yaml up -d sleep 5s @@ -165,28 +175,37 @@ function start_planexec_agent_service_vllm() { echo "Service started successfully" } -function start_react_langgraph_agent_service_openai() { - echo "Starting react_langgraph agent microservice" - docker run -d --runtime=runc --name="test-comps-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9095:9095 --ipc=host -e model=gpt-4o-mini-2024-07-18 -e strategy=react_langgraph -e llm_engine=openai -e OPENAI_API_KEY=${OPENAI_API_KEY} -e recursion_limit=10 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:comps +function start_ragagent_agent_service() { + echo "Starting rag agent microservice" + docker compose -f $WORKPATH/tests/agent/ragagent.yaml up -d sleep 5s docker logs test-comps-agent-endpoint echo "Service started successfully" } - -function start_ragagent_agent_service() { +function start_ragagent_agent_service_openai() { echo "Starting rag agent microservice" - docker compose -f $WORKPATH/tests/agent/ragagent.yaml up -d + docker compose -f $WORKPATH/tests/agent/ragagent_openai.yaml up -d sleep 5s docker logs test-comps-agent-endpoint echo "Service started successfully" } +function start_planexec_agent_service_openai() { + echo "Starting plan execute agent microservice" + docker compose -f $WORKPATH/tests/agent/planexec_openai.yaml up -d + sleep 5s + docker logs test-comps-agent-endpoint + echo "Service started successfully" +} function validate() { local CONTENT="$1" local EXPECTED_RESULT="$2" local SERVICE_NAME="$3" + # local CONTENT_TO_VALIDATE= "$CONTENT" | grep -oP '(?<=text:).*?(?=prompt)' + echo "Content: $CONTENT" + # echo "Content to validate: $CONTENT_TO_VALIDATE" if echo "$CONTENT" | grep -q "$EXPECTED_RESULT"; then echo "[ $SERVICE_NAME ] Content is as expected: $CONTENT" @@ -199,9 +218,27 @@ function validate() { function validate_microservice() { echo "Testing agent service - chat completion API" - local CONTENT=$(http_proxy="" curl http://${ip_address}:9095/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{ - "query": "What is Intel OPEA project?" - }') + # local CONTENT=$(http_proxy="" curl http://${ip_address}:9095/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{ + # "query": "What is OPEA?" + # }') + CONTENT=$(python3 $WORKPATH/tests/agent/test.py) + local EXIT_CODE=$(validate "$CONTENT" "OPEA" "test-agent-langchain") + echo "$EXIT_CODE" + local EXIT_CODE="${EXIT_CODE:0-1}" + echo "return value is $EXIT_CODE" + if [ "$EXIT_CODE" == "1" ]; then + echo "==================TGI logs ======================" + docker logs test-comps-tgi-gaudi-service + echo "==================Agent logs ======================" + docker logs test-comps-agent-endpoint + exit 1 + fi +} + + +function validate_microservice_streaming() { + echo "Testing agent service - chat completion API" + CONTENT=$(python3 $WORKPATH/tests/agent/test.py --stream) local EXIT_CODE=$(validate "$CONTENT" "OPEA" "test-agent-langchain") echo "$EXIT_CODE" local EXIT_CODE="${EXIT_CODE:0-1}" @@ -275,7 +312,7 @@ function main() { echo "=============================================" # test react_llama - start_react_langgraph_agent_service + start_react_llama_agent_service echo "===========Testing ReAct Llama =============" validate_microservice stop_agent_docker @@ -285,7 +322,7 @@ function main() { # test react_langchain start_react_langchain_agent_service echo "=============Testing ReAct Langchain=============" - validate_microservice + validate_microservice_streaming validate_assistant_api stop_agent_docker echo "=============================================" @@ -300,19 +337,19 @@ function main() { export model_parser=mistral export LLM_ENDPOINT_URL="http://${ip_address}:${vllm_port}" - # test react with vllm + # test react with vllm - Mistral start_vllm_auto_tool_choice_service start_react_langgraph_agent_service_vllm - echo "===========Testing ReAct VLLM =============" + echo "===========Testing ReAct Langgraph VLLM Mistral =============" validate_microservice - stop_agent_docker - stop_vllm_docker + # stop_agent_docker + # stop_vllm_docker echo "=============================================" - # test plan execute with vllm + # test plan execute with vllm - Mistral start_vllm_service start_planexec_agent_service_vllm - echo "===========Testing Plan Execute VLLM =============" + echo "===========Testing Plan Execute VLLM Mistral =============" validate_microservice stop_agent_docker stop_vllm_docker @@ -331,15 +368,34 @@ function main() { # stop_vllm_docker # echo "=============================================" - # test plan execute with vllm + # test plan execute with vllm - llama3.1 start_vllm_service start_planexec_agent_service_vllm - echo "===========Testing Plan Execute VLLM =============" + echo "===========Testing Plan Execute VLLM Llama3.1 =============" validate_microservice stop_agent_docker stop_vllm_docker echo "=============================================" + + # # ==================== OpenAI tests ==================== + # start_ragagent_agent_service_openai + # echo "=============Testing RAG Agent OpenAI=============" + # validate_microservice + # stop_agent_docker + # echo "=============================================" + + # start_react_langgraph_agent_service_openai + # echo "===========Testing ReAct Langgraph OpenAI =============" + # validate_microservice + # stop_agent_docker + # echo "=============================================" + + # start_planexec_agent_service_openai + # echo "===========Testing Plan Execute OpenAI =============" + # validate_microservice + # stop_agent_docker + stop_docker echo y | docker system prune 2>&1 > /dev/null }