Commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Aug 19, 2024
1 parent 27a0510 commit 71e2658
Showing 13 changed files with 118 additions and 61 deletions.
56 changes: 43 additions & 13 deletions comps/agent/langchain/agent.py
@@ -5,24 +5,36 @@
 import os
 import pathlib
 import sys
+from datetime import datetime
 from typing import Union
 
 from fastapi.responses import StreamingResponse
-from datetime import datetime
 
 cur_path = pathlib.Path(__file__).parent.resolve()
 comps_path = os.path.join(cur_path, "../../../")
 sys.path.append(comps_path)
 
 from comps import GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
-from comps.cores.proto.api_protocol import ChatCompletionRequest, CreateAssistantsRequest, AssistantsObject, ThreadObject, MessageObject, RunObject, MessageContent, CreateThreadsRequest, CreateMessagesRequest, CreateRunResponse
 from comps.agent.langchain.src.agent import instantiate_agent
-from comps.agent.langchain.src.thread import instantiate_thread_memory,thread_completion_callback
-from comps.agent.langchain.src.utils import get_args
 from comps.agent.langchain.src.global_var import assistants_global_kv, threads_global_kv
+from comps.agent.langchain.src.thread import instantiate_thread_memory, thread_completion_callback
+from comps.agent.langchain.src.utils import get_args
+from comps.cores.proto.api_protocol import (
+    AssistantsObject,
+    ChatCompletionRequest,
+    CreateAssistantsRequest,
+    CreateMessagesRequest,
+    CreateRunResponse,
+    CreateThreadsRequest,
+    MessageContent,
+    MessageObject,
+    RunObject,
+    ThreadObject,
+)
 
 args, _ = get_args()
 
+
 @register_microservice(
     name="opea_service@comps-chat-agent",
     service_type=ServiceType.LLM,
@@ -47,7 +59,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest]):
         input_query = input.messages
     else:
         input_query = input.messages[-1]["content"]
-
+
     # 2. prepare the input for the agent
     if input.streaming:
         print("-----------STREAMING-------------")
@@ -78,7 +90,10 @@ def create_assistants(input: CreateAssistantsRequest):
         print(f"Record assistant inst {agent_id} in global KV")
 
     # get current time in string format
-    return AssistantsObject(id=agent_id, created_at=created_at,)
+    return AssistantsObject(
+        id=agent_id,
+        created_at=created_at,
+    )
 
 
 @register_microservice(
@@ -96,7 +111,10 @@ def create_threads(input: CreateThreadsRequest):
         g_threads[thread_id] = (thread_inst, created_at, status)
         print(f"Record thread inst {thread_id} in global KV")
 
-    return ThreadObject(id=thread_id, created_at=created_at,)
+    return ThreadObject(
+        id=thread_id,
+        created_at=created_at,
+    )
 
 
 @register_microservice(
@@ -116,9 +134,16 @@ def create_messages(thread_id, input: CreateMessagesRequest):
     else:
         query = input.content[-1]["text"]
     msg_id, created_at = thread_inst.add_query(query)
 
     structured_content = MessageContent(text=query)
-    return MessageObject(id=msg_id, created_at=created_at, thread_id=thread_id, role=role, content=[structured_content],)
+    return MessageObject(
+        id=msg_id,
+        created_at=created_at,
+        thread_id=thread_id,
+        role=role,
+        content=[structured_content],
+    )
+
 
 @register_microservice(
     name="opea_service@comps-chat-agent",
@@ -129,24 +154,28 @@ def create_messages(thread_id, input: CreateMessagesRequest):
 def create_run(thread_id, input: CreateRunResponse):
     with threads_global_kv as g_threads:
         thread_inst, _, status = g_threads[thread_id]
 
     if status == "running":
         return "[error] Thread is already running, need to cancel the current run or wait for it to finish"
 
     agent_id = input.assistant_id
     with assistants_global_kv as g_assistants:
         agent_inst, _ = g_assistants[agent_id]
 
     config = {"recursion_limit": args.recursion_limit}
     input_query = thread_inst.get_query()
     try:
-        return StreamingResponse(thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id), media_type="text/event-stream")
+        return StreamingResponse(
+            thread_completion_callback(agent_inst.stream_generator(input_query, config, thread_id), thread_id),
+            media_type="text/event-stream",
+        )
     except Exception as e:
         with threads_global_kv as g_threads:
             thread_inst, created_at, status = g_threads[thread_id]
             g_threads[thread_id] = (thread_inst, created_at, "ready")
         return f"An error occurred: {e}. This thread is now set as ready"
 
+
 @register_microservice(
     name="opea_service@comps-chat-agent",
     endpoint="/v1/threads/{thread_id}/runs/cancel",
@@ -164,5 +193,6 @@ def cancel_run(thread_id):
             g_threads[thread_id] = (thread_inst, created_at, "try_cancel")
             return "submit cancel request"
 
+
 if __name__ == "__main__":
     opea_microservices["opea_service@comps-chat-agent"].start()
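
Taken together, the endpoints in this file follow an OpenAI-Assistants-style flow: create an assistant, create a thread, attach a message, then start a run that streams results. A minimal client sketch of that flow follows; the base URL and every path except the /v1/threads/{thread_id}/runs/cancel route visible above are assumptions, since the other endpoint strings sit outside the visible hunks.

# Hypothetical client walkthrough. The base URL and all paths except
# /v1/threads/{thread_id}/runs/cancel are assumptions, not confirmed by this diff.
import requests

BASE = "http://localhost:9090/v1"  # assumed host/port

assistant = requests.post(f"{BASE}/assistants", json={}).json()
thread = requests.post(f"{BASE}/threads", json={}).json()

requests.post(
    f"{BASE}/threads/{thread['id']}/messages",
    json={"role": "user", "content": [{"type": "text", "text": "What is OPEA?"}]},
)

# create_run streams server-sent events; assistant_id binds the run to an agent.
with requests.post(
    f"{BASE}/threads/{thread['id']}/runs",
    json={"assistant_id": assistant["id"]},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)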
4 changes: 2 additions & 2 deletions comps/agent/langchain/requirements.txt
@@ -5,13 +5,13 @@ docarray[full]
 duckduckgo-search
 fastapi
 huggingface_hub==0.24.0
-langgraph
-langsmith
 langchain==0.2.9
 langchain-huggingface
 langchain-openai
 langchain_community==0.2.7
 langchainhub==0.1.20
+langgraph
+langsmith
 numpy
 
 # used by cloud native
1 change: 1 addition & 0 deletions comps/agent/langchain/src/agent.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+
 def instantiate_agent(args, strategy="react_langchain", with_memory=False):
     if strategy == "react_langchain":
         from .strategy.react import ReActAgentwithLangchain
18 changes: 11 additions & 7 deletions comps/agent/langchain/src/global_var.py
@@ -1,17 +1,21 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
 import threading
 
-class ThreadSafeDict(dict) :
-    def __init__(self, * p_arg, ** n_arg) :
-        dict.__init__(self, * p_arg, ** n_arg)
+
+class ThreadSafeDict(dict):
+    def __init__(self, *p_arg, **n_arg):
+        dict.__init__(self, *p_arg, **n_arg)
         self._lock = threading.Lock()
 
-    def __enter__(self) :
+    def __enter__(self):
         self._lock.acquire()
         return self
 
-    def __exit__(self, type, value, traceback) :
+    def __exit__(self, type, value, traceback):
         self._lock.release()
 
 
 assistants_global_kv = ThreadSafeDict()
-threads_global_kv = ThreadSafeDict()
+threads_global_kv = ThreadSafeDict()
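
For orientation: ThreadSafeDict makes each instance its own lock, so entering the dict as a context manager serializes access. This is the with ... as g_threads pattern used throughout agent.py and thread.py. A small illustrative sketch:

# Illustrative only: the instance is both the mapping and the lock guarding it.
from comps.agent.langchain.src.global_var import threads_global_kv

def mark_ready(thread_id):
    with threads_global_kv as g_threads:  # __enter__ acquires the lock
        thread_inst, created_at, _ = g_threads[thread_id]
        g_threads[thread_id] = (thread_inst, created_at, "ready")
    # __exit__ releases the lock, even if the lookup above raises

One coarse lock per dict is a simple design that fits this use: each entry is touched only briefly around request handling, so contention stays low.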
5 changes: 3 additions & 2 deletions comps/agent/langchain/src/strategy/base_agent.py
@@ -1,9 +1,10 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+from uuid import uuid4
+
 from ..tools import get_tools_descriptions
 from ..utils import setup_llm
-from uuid import uuid4
 
 
 class BaseAgent:
@@ -22,4 +23,4 @@ def execute(self, state: dict):
         pass
 
     def non_streaming_run(self, query, config):
-        raise NotImplementedError
+        raise NotImplementedError
6 changes: 3 additions & 3 deletions comps/agent/langchain/src/strategy/planexec/planner.py
@@ -16,10 +16,11 @@
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.utils.json import parse_partial_json
 from langchain_huggingface import ChatHuggingFace
+from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.message import add_messages
-from langgraph.checkpoint.memory import MemorySaver
 
+from ...global_var import threads_global_kv
 from ...utils import has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
 from .prompt import (
@@ -30,7 +31,6 @@
     planner_prompt,
     replanner_prompt,
 )
-from ...global_var import threads_global_kv
 
 # Define protocol
 
@@ -273,4 +273,4 @@ async def stream_generator(self, query, config, thread_id=None):
                     yield f"{k}: {v}\n"
 
             yield f"data: {repr(event)}\n\n"
-        yield "data: [DONE]\n\n"
+        yield "data: [DONE]\n\n"
2 changes: 1 addition & 1 deletion comps/agent/langchain/src/strategy/ragagent/planner.py
@@ -11,10 +11,10 @@
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import ToolNode, tools_condition
-from langgraph.checkpoint.memory import MemorySaver
 
 from ..base_agent import BaseAgent
 from .prompt import DOC_GRADER_PROMPT, RAG_PROMPT
17 changes: 10 additions & 7 deletions comps/agent/langchain/src/strategy/react/planner.py
@@ -3,19 +3,19 @@
 
 from langchain.agents import AgentExecutor
 from langchain.agents import create_react_agent as create_react_langchain_agent
+from langchain.memory import ChatMessageHistory
 from langchain_core.messages import HumanMessage
+from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_openai import ChatOpenAI
+from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 
+from ...global_var import threads_global_kv
 from ...utils import has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
 
 from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt
-from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain.memory import ChatMessageHistory
-from langgraph.checkpoint.memory import MemorySaver
-from ...global_var import threads_global_kv
 
 
 class ReActAgentwithLangchain(BaseAgent):
@@ -31,6 +31,7 @@ def __init__(self, args, with_memory=False):
             agent=agent_chain, tools=self.tools_descriptions, verbose=True, handle_parsing_errors=True
         )
         self.memory = {}
+
         def get_session_history(session_id):
             if session_id in self.memory:
                 return self.memory[session_id]
@@ -46,7 +47,7 @@ def get_session_history(session_id):
             input_messages_key="input",
             history_messages_key="chat_history",
             history_factory_config=[],
-            )
+        )
 
     def prepare_initial_state(self, query):
         return {"input": query}
@@ -92,7 +93,9 @@ def __init__(self, args, with_memory=False):
         tools = self.tools_descriptions
 
         if with_memory:
-            self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE, checkpointer=MemorySaver())
+            self.app = create_react_agent(
+                self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE, checkpointer=MemorySaver()
+            )
         else:
             self.app = create_react_agent(self.llm, tools=tools, state_modifier=REACT_SYS_MESSAGE)
 
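
This file now carries two memory mechanisms: RunnableWithMessageHistory backed by an in-process dict for the LangChain AgentExecutor path, and a MemorySaver checkpointer for the LangGraph path. With a checkpointer, conversation state is keyed by the thread_id in the run config, roughly as in this sketch; llm and tools are placeholders here:

# Sketch of the LangGraph memory path; llm and tools are placeholders,
# REACT_SYS_MESSAGE is the prompt imported in the diff above.
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

app = create_react_agent(llm, tools=tools, state_modifier=REACT_SYS_MESSAGE, checkpointer=MemorySaver())

# Reusing the same thread_id resumes the same checkpointed conversation.
config = {"configurable": {"thread_id": "thread_abc"}, "recursion_limit": 5}
result = app.invoke({"messages": [("user", "hello")]}, config)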
23 changes: 15 additions & 8 deletions comps/agent/langchain/src/thread.py
@@ -1,24 +1,30 @@
-from uuid import uuid4
-from datetime import datetime
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
 from collections import deque
+from datetime import datetime
+from uuid import uuid4
 
 from .global_var import threads_global_kv
 
+
 class ThreadMemory:
     def __init__(self):
         self.query_list = deque()
-    def add_query(self, query):
+
+    def add_query(self, query):
         msg_id = f"msg_{uuid4()}"
         created_at = int(datetime.now().timestamp())
 
         self.query_list.append((query, msg_id, created_at))
 
         return msg_id, created_at
 
     def get_query(self):
         query, _, _ = self.query_list.pop()
         return query
 
+
 async def thread_completion_callback(content, thread_id):
     with threads_global_kv as g_threads:
         thread_inst, created_at, _ = g_threads[thread_id]
@@ -31,6 +37,7 @@ async def thread_completion_callback(content, thread_id):
             g_threads[thread_id] = (thread_inst, created_at, "ready")
         yield chunk
 
+
 def instantiate_thread_memory(args=None):
     thread_id = f"thread_{uuid4()}"
-    return ThreadMemory(), thread_id
+    return ThreadMemory(), thread_id
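
One behavioral detail worth noting in ThreadMemory: add_query appends to the right of the deque, and get_query uses deque.pop(), which also takes from the right, so a run consumes the most recently added query rather than the oldest. A short illustrative check:

# Illustrative check of the pop-from-the-right behavior described above.
from comps.agent.langchain.src.thread import instantiate_thread_memory

mem, thread_id = instantiate_thread_memory()
mem.add_query("first question")
msg_id, created_at = mem.add_query("second question")

print(mem.get_query())  # -> "second question" (LIFO, not FIFO)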