agent short & long term memory with langgraph. (#851)
* draft a demo code for memory.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add agent short-term memory with langgraph checkpoint.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add save long-term memory func.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add save long-term memory func.

* add timeout for llm response.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ut with adding -e HABANA_VISIBLE_DEVICES=all.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lkk12014402 and pre-commit-ci[bot] authored Nov 12, 2024
1 parent 24b9f03 commit e39b08f
Showing 6 changed files with 120 additions and 20 deletions.
37 changes: 23 additions & 14 deletions comps/agent/langchain/agent.py
@@ -35,6 +35,15 @@

args, _ = get_args()

logger.info("========initiating agent============")
logger.info(f"args: {args}")
agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)


class AgentCompletionRequest(LLMParamsDoc):
thread_id: str = "0"
user_id: str = "0"


@register_microservice(
name="opea_service@comps-chat-agent",
@@ -43,16 +52,19 @@
host="0.0.0.0",
port=args.port,
)
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]):
if logflag:
logger.info(input)
# 1. initialize the agent
if logflag:
logger.info(f"args: {args}")

input.streaming = args.streaming
config = {"recursion_limit": args.recursion_limit}
print("========initiating agent============")
agent_inst = instantiate_agent(args, args.strategy)

if args.with_memory:
if isinstance(input, AgentCompletionRequest):
config["configurable"] = {"thread_id": input.thread_id}
else:
config["configurable"] = {"thread_id": "0"}

if logflag:
logger.info(type(agent_inst))

@@ -68,14 +80,13 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):

# 2. prepare the input for the agent
if input.streaming:
print("-----------STREAMING-------------")
logger.info("-----------STREAMING-------------")
return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")

else:
print("-----------NOT STREAMING-------------")
logger.info("-----------NOT STREAMING-------------")
response = await agent_inst.non_streaming_run(input_query, config)
print("-----------Response-------------")
print(response)
logger.info("-----------Response-------------")
return GeneratedDoc(text=response, prompt=input_query)


@@ -87,13 +98,11 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
)
def create_assistants(input: CreateAssistantsRequest):
# 1. initialize the agent
print("args: ", args)
agent_inst = instantiate_agent(args, args.strategy, with_memory=True)
agent_id = agent_inst.id
created_at = int(datetime.now().timestamp())
with assistants_global_kv as g_assistants:
g_assistants[agent_id] = (agent_inst, created_at)
print(f"Record assistant inst {agent_id} in global KV")
logger.info(f"Record assistant inst {agent_id} in global KV")

# get current time in string format
return AssistantsObject(
@@ -115,7 +124,7 @@ def create_threads(input: CreateThreadsRequest):
status = "ready"
with threads_global_kv as g_threads:
g_threads[thread_id] = (thread_inst, created_at, status)
print(f"Record thread inst {thread_id} in global KV")
logger.info(f"Record thread inst {thread_id} in global KV")

return ThreadObject(
id=thread_id,
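For reference, here is a hedged client-side sketch (not part of this commit) of how the new AgentCompletionRequest fields could be exercised. The endpoint path, port, and the query/streaming payload fields are assumptions based on LLMParamsDoc and the usual OPEA chat-completions route; only thread_id and user_id come from the diff above.

import requests

# assumed endpoint for opea_service@comps-chat-agent; adjust host/port/path to your deployment
url = "http://localhost:9090/v1/chat/completions"
payload = {
    "query": "What is OPEA?",  # assumed LLMParamsDoc field
    "streaming": False,        # overridden server-side by args.streaming
    "thread_id": "42",         # becomes config["configurable"]["thread_id"] (short-term memory)
    "user_id": "alice",        # reserved for long-term memory lookups
}
resp = requests.post(url, json=payload, timeout=60)
print(resp.json())

Reusing the same thread_id on a follow-up request lets the langgraph checkpointer resume the earlier conversation state.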
9 changes: 9 additions & 0 deletions comps/agent/langchain/src/config.py
@@ -63,3 +63,12 @@

if os.environ.get("custom_prompt") is not None:
env_config += ["--custom_prompt", os.environ["custom_prompt"]]

if os.environ.get("with_memory") is not None:
env_config += ["--with_memory", os.environ["with_memory"]]

if os.environ.get("with_store") is not None:
env_config += ["--with_store", os.environ["with_store"]]

if os.environ.get("timeout") is not None:
env_config += ["--timeout", os.environ["timeout"]]
68 changes: 68 additions & 0 deletions comps/agent/langchain/src/persistence.py
@@ -0,0 +1,68 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import uuid
from datetime import datetime
from typing import List, Optional

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.store.memory import InMemoryStore
from pydantic import BaseModel


class PersistenceConfig(BaseModel):
checkpointer: bool = False
store: bool = False


class PersistenceInfo(BaseModel):
user_id: str = None
thread_id: str = None
started_at: datetime


class AgentPersistence:
def __init__(self, config: PersistenceConfig):
# for short-term memory
self.checkpointer = None
# for long-term memory
self.store = None
self.config = config
print(f"Initializing AgentPersistence: {config}")
self.initialize()

def initialize(self) -> None:
if self.config.checkpointer:
self.checkpointer = MemorySaver()
if self.config.store:
self.store = InMemoryStore()

def save(
self,
config: RunnableConfig,
content: str,
context: str,
memory_id: Optional[str] = None,
):
"""This function is only for long-term memory."""
mem_id = memory_id or uuid.uuid4()
user_id = config["configurable"]["user_id"]
self.store.put(
("memories", user_id),
key=str(mem_id),
value={"content": content, "context": context},
)
return f"Stored memory {content}"

def get(self, config: RunnableConfig):
"""This function is only for long-term memory."""
user_id = config["configurable"]["user_id"]
namespace = ("memories", user_id)
memories = self.store.search(namespace)
return memories

def update_state(self, config, graph: StateGraph):
pass
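A minimal usage sketch for the new class (illustration only, not part of the commit). The user_id and memory content are made-up values, and the printed Item attributes assume langgraph's in-memory store API; the RunnableConfig is passed as a plain dict, which is all save()/get() rely on.

# import path follows this repo's layout; adjust to your installation
from comps.agent.langchain.src.persistence import AgentPersistence, PersistenceConfig

persistence = AgentPersistence(PersistenceConfig(checkpointer=True, store=True))

runnable_config = {"configurable": {"user_id": "alice", "thread_id": "0"}}
persistence.save(runnable_config, content="prefers concise answers", context="chat preferences")

for item in persistence.get(runnable_config):
    print(item.key, item.value)  # e.g. {'content': 'prefers concise answers', 'context': 'chat preferences'}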
9 changes: 7 additions & 2 deletions comps/agent/langchain/src/strategy/react/planner.py
@@ -145,6 +145,7 @@ async def non_streaming_run(self, query, config):
from langgraph.managed import IsLastStep
from langgraph.prebuilt import ToolNode

from ...persistence import AgentPersistence, PersistenceConfig
from ...utils import setup_chat_model


@@ -248,8 +249,12 @@ def __init__(self, args, with_memory=False, **kwargs):
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("tools", "agent")

if with_memory:
self.app = workflow.compile(checkpointer=MemorySaver())
if args.with_memory:
self.persistence = AgentPersistence(
config=PersistenceConfig(checkpointer=args.with_memory, store=args.with_store)
)
print(self.persistence.checkpointer)
self.app = workflow.compile(checkpointer=self.persistence.checkpointer, store=self.persistence.store)
else:
self.app = workflow.compile()

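A self-contained sketch (assuming a recent langgraph release) of what compiling with a checkpointer provides: invocations that share a thread_id also share graph state, which is the short-term memory wired up above. The echo node is a stand-in for the ReAct planner's LLM call.

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph

def agent_node(state: MessagesState):
    # placeholder for the real LLM call; just reports how much history it sees
    return {"messages": [("ai", f"seen {len(state['messages'])} message(s) so far")]}

workflow = StateGraph(MessagesState)
workflow.add_node("agent", agent_node)
workflow.add_edge(START, "agent")
workflow.add_edge("agent", END)

app = workflow.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "0"}}
app.invoke({"messages": [("user", "hi")]}, config)
out = app.invoke({"messages": [("user", "hi again")]}, config)
print(out["messages"][-1].content)  # reflects both turns because the thread_id matched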
13 changes: 11 additions & 2 deletions comps/agent/langchain/src/utils.py
@@ -57,9 +57,15 @@ def setup_chat_model(args):
}
if args.llm_engine == "vllm" or args.llm_engine == "tgi":
openai_endpoint = f"{args.llm_endpoint_url}/v1"
llm = ChatOpenAI(openai_api_key="EMPTY", openai_api_base=openai_endpoint, model_name=args.model, **params)
llm = ChatOpenAI(
openai_api_key="EMPTY",
openai_api_base=openai_endpoint,
model_name=args.model,
request_timeout=args.timeout,
**params,
)
elif args.llm_engine == "openai":
llm = ChatOpenAI(model_name=args.model, **params)
llm = ChatOpenAI(model_name=args.model, request_timeout=args.timeout, **params)
else:
raise ValueError("llm_engine must be vllm, tgi or openai")
return llm
@@ -129,6 +135,9 @@ def get_args():
parser.add_argument("--repetition_penalty", type=float, default=1.03)
parser.add_argument("--return_full_text", type=bool, default=False)
parser.add_argument("--custom_prompt", type=str, default=None)
parser.add_argument("--with_memory", type=bool, default=False)
parser.add_argument("--with_store", type=bool, default=False)
parser.add_argument("--timeout", type=int, default=60)

sys_args, unknown_args = parser.parse_known_args()
# print("env_config: ", env_config)
4 changes: 2 additions & 2 deletions tests/agent/test_agent_langchain_on_intel_hpu.sh
@@ -87,7 +87,7 @@ function start_vllm_service() {

#single card
echo "start vllm gaudi service"
docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192
docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192
sleep 5s
echo "Waiting vllm gaudi ready"
n=0
@@ -113,7 +113,7 @@ function start_vllm_auto_tool_choice_service() {

#single card
echo "start vllm gaudi service"
docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser ${model_parser}
docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm:hpu --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser ${model_parser}
sleep 5s
echo "Waiting vllm gaudi ready"
n=0
