From e39b08f3d04bffc3e23da155992c1c6ed845c8e4 Mon Sep 17 00:00:00 2001 From: lkk <33276950+lkk12014402@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:28:37 +0800 Subject: [PATCH] agent short & long term memory with langgraph. (#851) * draft a demo code for memory. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add agent short-term memory with langgraph checkpoint. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save long-term memory func. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save long-term memory func. * add timeout for llm response. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ut with adding -e HABANA_VISIBLE_DEVICES=all. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- comps/agent/langchain/agent.py | 37 ++++++---- comps/agent/langchain/src/config.py | 9 +++ comps/agent/langchain/src/persistence.py | 68 +++++++++++++++++++ .../langchain/src/strategy/react/planner.py | 9 ++- comps/agent/langchain/src/utils.py | 13 +++- .../test_agent_langchain_on_intel_hpu.sh | 4 +- 6 files changed, 120 insertions(+), 20 deletions(-) create mode 100644 comps/agent/langchain/src/persistence.py diff --git a/comps/agent/langchain/agent.py b/comps/agent/langchain/agent.py index 7eb44bc00..8eb87d146 100644 --- a/comps/agent/langchain/agent.py +++ b/comps/agent/langchain/agent.py @@ -35,6 +35,15 @@ args, _ = get_args() +logger.info("========initiating agent============") +logger.info(f"args: {args}") +agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory) + + +class AgentCompletionRequest(LLMParamsDoc): + thread_id: str = "0" + user_id: str = "0" + @register_microservice( name="opea_service@comps-chat-agent", @@ -43,16 +52,19 @@ host="0.0.0.0", port=args.port, ) -async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): +async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]): if logflag: logger.info(input) - # 1. initialize the agent - if logflag: - logger.info(f"args: {args}") + input.streaming = args.streaming config = {"recursion_limit": args.recursion_limit} - print("========initiating agent============") - agent_inst = instantiate_agent(args, args.strategy) + + if args.with_memory: + if isinstance(input, AgentCompletionRequest): + config["configurable"] = {"thread_id": input.thread_id} + else: + config["configurable"] = {"thread_id": "0"} + if logflag: logger.info(type(agent_inst)) @@ -68,14 +80,13 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): # 2. 
prepare the input for the agent if input.streaming: - print("-----------STREAMING-------------") + logger.info("-----------STREAMING-------------") return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream") else: - print("-----------NOT STREAMING-------------") + logger.info("-----------NOT STREAMING-------------") response = await agent_inst.non_streaming_run(input_query, config) - print("-----------Response-------------") - print(response) + logger.info("-----------Response-------------") return GeneratedDoc(text=response, prompt=input_query) @@ -87,13 +98,11 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]): ) def create_assistants(input: CreateAssistantsRequest): # 1. initialize the agent - print("args: ", args) - agent_inst = instantiate_agent(args, args.strategy, with_memory=True) agent_id = agent_inst.id created_at = int(datetime.now().timestamp()) with assistants_global_kv as g_assistants: g_assistants[agent_id] = (agent_inst, created_at) - print(f"Record assistant inst {agent_id} in global KV") + logger.info(f"Record assistant inst {agent_id} in global KV") # get current time in string format return AssistantsObject( @@ -115,7 +124,7 @@ def create_threads(input: CreateThreadsRequest): status = "ready" with threads_global_kv as g_threads: g_threads[thread_id] = (thread_inst, created_at, status) - print(f"Record thread inst {thread_id} in global KV") + logger.info(f"Record thread inst {thread_id} in global KV") return ThreadObject( id=thread_id, diff --git a/comps/agent/langchain/src/config.py b/comps/agent/langchain/src/config.py index 9d4cf3574..4178e2d9f 100644 --- a/comps/agent/langchain/src/config.py +++ b/comps/agent/langchain/src/config.py @@ -63,3 +63,12 @@ if os.environ.get("custom_prompt") is not None: env_config += ["--custom_prompt", os.environ["custom_prompt"]] + +if os.environ.get("with_memory") is not None: + env_config += ["--with_memory", os.environ["with_memory"]] + +if os.environ.get("with_store") is not None: + env_config += ["--with_store", os.environ["with_store"]] + +if os.environ.get("timeout") is not None: + env_config += ["--timeout", os.environ["timeout"]] diff --git a/comps/agent/langchain/src/persistence.py b/comps/agent/langchain/src/persistence.py new file mode 100644 index 000000000..43fafb915 --- /dev/null +++ b/comps/agent/langchain/src/persistence.py @@ -0,0 +1,68 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import uuid +from datetime import datetime +from typing import List, Optional + +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import StateGraph +from langgraph.store.memory import InMemoryStore +from pydantic import BaseModel + + +class PersistenceConfig(BaseModel): + checkpointer: bool = False + store: bool = False + + +class PersistenceInfo(BaseModel): + user_id: str = None + thread_id: str = None + started_at: datetime + + +class AgentPersistence: + def __init__(self, config: PersistenceConfig): + # for short-term memory + self.checkpointer = None + # for long-term memory + self.store = None + self.config = config + print(f"Initializing AgentPersistence: {config}") + self.initialize() + + def initialize(self) -> None: + if self.config.checkpointer: + self.checkpointer = MemorySaver() + if self.config.store: + self.store = InMemoryStore() + + def save( + self, + config: RunnableConfig, + content: str, + context: str, + memory_id: 
Optional[str] = None, + ): + """This function is only for long-term memory.""" + mem_id = memory_id or uuid.uuid4() + user_id = config["configurable"]["user_id"] + self.store.put( + ("memories", user_id), + key=str(mem_id), + value={"content": content, "context": context}, + ) + return f"Stored memory {content}" + + def get(self, config: RunnableConfig): + """This function is only for long-term memory.""" + user_id = config["configurable"]["user_id"] + namespace = ("memories", user_id) + memories = self.store.search(namespace) + return memories + + def update_state(self, config, graph: StateGraph): + pass diff --git a/comps/agent/langchain/src/strategy/react/planner.py b/comps/agent/langchain/src/strategy/react/planner.py index f574b5f65..9771f6220 100644 --- a/comps/agent/langchain/src/strategy/react/planner.py +++ b/comps/agent/langchain/src/strategy/react/planner.py @@ -145,6 +145,7 @@ async def non_streaming_run(self, query, config): from langgraph.managed import IsLastStep from langgraph.prebuilt import ToolNode +from ...persistence import AgentPersistence, PersistenceConfig from ...utils import setup_chat_model @@ -248,8 +249,12 @@ def __init__(self, args, with_memory=False, **kwargs): # This means that after `tools` is called, `agent` node is called next. workflow.add_edge("tools", "agent") - if with_memory: - self.app = workflow.compile(checkpointer=MemorySaver()) + if args.with_memory: + self.persistence = AgentPersistence( + config=PersistenceConfig(checkpointer=args.with_memory, store=args.with_store) + ) + print(self.persistence.checkpointer) + self.app = workflow.compile(checkpointer=self.persistence.checkpointer, store=self.persistence.store) else: self.app = workflow.compile() diff --git a/comps/agent/langchain/src/utils.py b/comps/agent/langchain/src/utils.py index fc1cde9ca..e8a317a5d 100644 --- a/comps/agent/langchain/src/utils.py +++ b/comps/agent/langchain/src/utils.py @@ -57,9 +57,15 @@ def setup_chat_model(args): } if args.llm_engine == "vllm" or args.llm_engine == "tgi": openai_endpoint = f"{args.llm_endpoint_url}/v1" - llm = ChatOpenAI(openai_api_key="EMPTY", openai_api_base=openai_endpoint, model_name=args.model, **params) + llm = ChatOpenAI( + openai_api_key="EMPTY", + openai_api_base=openai_endpoint, + model_name=args.model, + request_timeout=args.timeout, + **params, + ) elif args.llm_engine == "openai": - llm = ChatOpenAI(model_name=args.model, **params) + llm = ChatOpenAI(model_name=args.model, request_timeout=args.timeout, **params) else: raise ValueError("llm_engine must be vllm, tgi or openai") return llm @@ -129,6 +135,9 @@ def get_args(): parser.add_argument("--repetition_penalty", type=float, default=1.03) parser.add_argument("--return_full_text", type=bool, default=False) parser.add_argument("--custom_prompt", type=str, default=None) + parser.add_argument("--with_memory", type=bool, default=False) + parser.add_argument("--with_store", type=bool, default=False) + parser.add_argument("--timeout", type=int, default=60) sys_args, unknown_args = parser.parse_known_args() # print("env_config: ", env_config) diff --git a/tests/agent/test_agent_langchain_on_intel_hpu.sh b/tests/agent/test_agent_langchain_on_intel_hpu.sh index 4cc36164e..04da54285 100644 --- a/tests/agent/test_agent_langchain_on_intel_hpu.sh +++ b/tests/agent/test_agent_langchain_on_intel_hpu.sh @@ -87,7 +87,7 @@ function start_vllm_service() { #single card echo "start vllm gaudi service" - docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -p $vllm_port:80 -v 
$vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 + docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 sleep 5s echo "Waiting vllm gaudi ready" n=0 @@ -113,7 +113,7 @@ function start_vllm_auto_tool_choice_service() { #single card echo "start vllm gaudi service" - docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser ${model_parser} + docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser ${model_parser} sleep 5s echo "Waiting vllm gaudi ready" n=0
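For readers unfamiliar with the two memory primitives this patch wires in, here is a minimal standalone sketch (not part of the patch) of the same pattern: a LangGraph StateGraph compiled with a MemorySaver checkpointer for per-thread short-term memory and an InMemoryStore for cross-thread long-term memory, driven through config["configurable"] the way the patched agent.py and persistence.py do. The State definition and the echo-style agent_node are hypothetical placeholders, and the sketch assumes a langgraph release whose compile() accepts both the checkpointer and store keyword arguments, as the patched planner.py relies on.

# Hypothetical sketch of the memory wiring introduced by this PR; not part of the patch.
from typing import Annotated, TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.store.memory import InMemoryStore


class State(TypedDict):
    # Conversation history accumulated by the add_messages reducer.
    messages: Annotated[list, add_messages]


def agent_node(state: State):
    # Placeholder node; the real agent calls the LLM and tools here.
    return {"messages": [("ai", "echo: " + state["messages"][-1].content)]}


workflow = StateGraph(State)
workflow.add_node("agent", agent_node)
workflow.add_edge(START, "agent")
workflow.add_edge("agent", END)

# Short-term memory: per-thread checkpoints. Long-term memory: a cross-thread store.
checkpointer = MemorySaver()
store = InMemoryStore()
app = workflow.compile(checkpointer=checkpointer, store=store)

# The microservice passes thread_id (and user_id) via config["configurable"],
# mirroring AgentCompletionRequest in agent.py.
config = {"configurable": {"thread_id": "0", "user_id": "0"}}
print(app.invoke({"messages": [("user", "hello")]}, config))

# Long-term memories are namespaced per user, as in AgentPersistence.save()/get().
store.put(("memories", "0"), key="fav_color", value={"content": "blue", "context": "preference"})
print(store.search(("memories", "0")))

Re-invoking with the same thread_id resumes the checkpointed conversation, while store entries persist across threads within the same user namespace, which is what AgentPersistence.get() reads back.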