[Agent] support custom prompt (#798)
* Fixed habana_main issue; now uses the original Dockerfile

Signed-off-by: Chendi.Xue <[email protected]>

* Enable custom_prompt

Signed-off-by: Chendi.Xue <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update README

Signed-off-by: Chendi.Xue <[email protected]>

---------

Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xuechendi and pre-commit-ci[bot] authored Oct 22, 2024
1 parent 3a166c1 commit 3473bfb
Showing 12 changed files with 64 additions and 38 deletions.
3 changes: 1 addition & 2 deletions comps/agent/langchain/README.md
@@ -83,8 +83,7 @@ export vllm_volume=${YOUR_LOCAL_DIR_FOR_MODELS}

 # build vLLM image
 git clone https://github.com/HabanaAI/vllm-fork.git
-cd ./vllm-fork; git checkout habana_main; git tag v0.6.2.post1;
-cp ${your_path}/GenAIComps/tests/agent/Dockerfile.hpu ./
+cd ./vllm-fork
 docker build -f Dockerfile.hpu -t opea/vllm:hpu --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy

 # TGI serving
17 changes: 12 additions & 5 deletions comps/agent/langchain/src/agent.py
@@ -1,30 +1,37 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
+from .utils import load_python_prompt


 def instantiate_agent(args, strategy="react_langchain", with_memory=False):
+    if args.custom_prompt is not None:
+        print(f">>>>>> custom_prompt enabled, {args.custom_prompt}")
+        custom_prompt = load_python_prompt(args.custom_prompt)
+    else:
+        custom_prompt = None
+
     if strategy == "react_langchain":
         from .strategy.react import ReActAgentwithLangchain

-        return ReActAgentwithLangchain(args, with_memory)
+        return ReActAgentwithLangchain(args, with_memory, custom_prompt=custom_prompt)
     elif strategy == "react_langgraph":
         from .strategy.react import ReActAgentwithLanggraph

-        return ReActAgentwithLanggraph(args, with_memory)
+        return ReActAgentwithLanggraph(args, with_memory, custom_prompt=custom_prompt)
     elif strategy == "react_llama":
         print("Initializing ReAct Agent with LLAMA")
         from .strategy.react import ReActAgentLlama

-        return ReActAgentLlama(args, with_memory)
+        return ReActAgentLlama(args, with_memory, custom_prompt=custom_prompt)
     elif strategy == "plan_execute":
         from .strategy.planexec import PlanExecuteAgentWithLangGraph

-        return PlanExecuteAgentWithLangGraph(args, with_memory)
+        return PlanExecuteAgentWithLangGraph(args, with_memory, custom_prompt=custom_prompt)

     elif strategy == "rag_agent" or strategy == "rag_agent_llama":
         print("Initializing RAG Agent")
         from .strategy.ragagent import RAGAgent

-        return RAGAgent(args, with_memory)
+        return RAGAgent(args, with_memory, custom_prompt=custom_prompt)
     else:
         raise ValueError(f"Agent strategy: {strategy} not supported!")
3 changes: 3 additions & 0 deletions comps/agent/langchain/src/config.py
@@ -60,3 +60,6 @@

 if os.environ.get("return_full_text") is not None:
     env_config += ["--return_full_text", os.environ["return_full_text"]]
+
+if os.environ.get("custom_prompt") is not None:
+    env_config += ["--custom_prompt", os.environ["custom_prompt"]]
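As with the other optional settings, the new environment variable is bridged into get_args() as a synthetic CLI pair. A self-contained sketch of that bridge, assuming the same variable names as the repo (the path value is illustrative, and the pair is fed straight to parse_known_args() here instead of going through sys.argv):

import argparse
import os

# Simulate the deployment environment; the path below is illustrative only.
os.environ["custom_prompt"] = "/home/user/tools/custom_prompt.py"

# config.py-style translation: environment variable -> CLI-style argument pair.
env_config = []
if os.environ.get("custom_prompt") is not None:
    env_config += ["--custom_prompt", os.environ["custom_prompt"]]

# get_args()-style parsing; unknown arguments are tolerated by parse_known_args.
parser = argparse.ArgumentParser()
parser.add_argument("--custom_prompt", type=str, default=None)
args, _unknown = parser.parse_known_args(env_config)
print(args.custom_prompt)  # /home/user/tools/custom_prompt.py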
5 changes: 3 additions & 2 deletions comps/agent/langchain/src/strategy/base_agent.py
@@ -4,17 +4,18 @@
 from uuid import uuid4

 from ..tools import get_tools_descriptions
-from ..utils import setup_llm
+from ..utils import adapt_custom_prompt, setup_llm


 class BaseAgent:
-    def __init__(self, args) -> None:
+    def __init__(self, args, local_vars=None, **kwargs) -> None:
         self.llm_endpoint = setup_llm(args)
         self.tools_descriptions = get_tools_descriptions(args.tools)
         self.app = None
         self.memory = None
         self.id = f"assistant_{self.__class__.__name__}_{uuid4()}"
         self.args = args
+        adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
         print(self.tools_descriptions)

     @property
4 changes: 2 additions & 2 deletions comps/agent/langchain/src/strategy/planexec/planner.py
@@ -236,8 +236,8 @@ def __call__(self, state):


 class PlanExecuteAgentWithLangGraph(BaseAgent):
-    def __init__(self, args, with_memory=False):
-        super().__init__(args)
+    def __init__(self, args, with_memory=False, **kwargs):
+        super().__init__(args, local_vars=globals(), **kwargs)

         # Define Node
         plan_checker = PlanStepChecker(self.llm_endpoint, args.model, is_vllm=self.is_vllm)
4 changes: 2 additions & 2 deletions comps/agent/langchain/src/strategy/ragagent/planner.py
@@ -159,8 +159,8 @@ def __call__(self, state):


 class RAGAgent(BaseAgent):
-    def __init__(self, args, with_memory=False):
-        super().__init__(args)
+    def __init__(self, args, with_memory=False, **kwargs):
+        super().__init__(args, local_vars=globals(), **kwargs)

         # Define Nodes

13 changes: 7 additions & 6 deletions comps/agent/langchain/src/strategy/react/planner.py
@@ -17,8 +17,8 @@


 class ReActAgentwithLangchain(BaseAgent):
-    def __init__(self, args, with_memory=False):
-        super().__init__(args)
+    def __init__(self, args, with_memory=False, **kwargs):
+        super().__init__(args, local_vars=globals(), **kwargs)
         prompt = hwchase17_react_prompt
         if has_multi_tool_inputs(self.tools_descriptions):
             raise ValueError("Only supports single input tools when using strategy == react_langchain")
@@ -81,12 +81,13 @@ async def stream_generator(self, query, config, thread_id=None):


 class ReActAgentwithLanggraph(BaseAgent):
-    def __init__(self, args, with_memory=False):
-        super().__init__(args)
+    def __init__(self, args, with_memory=False, **kwargs):
+        super().__init__(args, local_vars=globals(), **kwargs)

         self.llm = wrap_chat(self.llm_endpoint, args.model)

         tools = self.tools_descriptions
+        print("REACT_SYS_MESSAGE: ", REACT_SYS_MESSAGE)

         if with_memory:
             self.app = create_react_agent(
@@ -207,8 +208,8 @@ def __call__(self, state):


 class ReActAgentLlama(BaseAgent):
-    def __init__(self, args, with_memory=False):
-        super().__init__(args)
+    def __init__(self, args, with_memory=False, **kwargs):
+        super().__init__(args, local_vars=globals(), **kwargs)
         agent = ReActAgentNodeLlama(
             llm_endpoint=self.llm_endpoint, model_id=args.model, tools=self.tools_descriptions, args=args
         )
19 changes: 19 additions & 0 deletions comps/agent/langchain/src/utils.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import argparse
+import importlib

 from .config import env_config

@@ -122,6 +123,23 @@ def has_multi_tool_inputs(tools):
     return ret


+def load_python_prompt(file_dir_path: str):
+    print(file_dir_path)
+    spec = importlib.util.spec_from_file_location("custom_prompt", file_dir_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def adapt_custom_prompt(local_vars, custom_prompt):
+    # list attributes of module
+    if custom_prompt is not None:
+        custom_prompt_list = [k for k in dir(custom_prompt) if k[:2] != "__"]
+        for k in custom_prompt_list:
+            v = getattr(custom_prompt, k)
+            local_vars[k] = v
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     # llm args
@@ -144,6 +162,7 @@ def get_args():
     parser.add_argument("--temperature", type=float, default=0.01)
     parser.add_argument("--repetition_penalty", type=float, default=1.03)
     parser.add_argument("--return_full_text", type=bool, default=False)
+    parser.add_argument("--custom_prompt", type=str, default=None)

     sys_args, unknown_args = parser.parse_known_args()
     # print("env_config: ", env_config)
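Taken together, load_python_prompt and adapt_custom_prompt implement the override mechanism: the user's file is imported as a module, and every non-dunder module-level name is copied into the caller's globals(), shadowing defaults such as REACT_SYS_MESSAGE. A simplified, self-contained sketch of that flow (the temporary prompt file is written purely for illustration; in the service the file is user-provided and mounted into the container):

import importlib.util
import tempfile

REACT_SYS_MESSAGE = "default system message"  # stand-in for a planner-module default

# Write an illustrative custom prompt file to a temporary path.
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write('REACT_SYS_MESSAGE = "overridden system message"\n')
    prompt_path = f.name

# load_python_prompt: import an arbitrary .py file as a module object.
spec = importlib.util.spec_from_file_location("custom_prompt", prompt_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# adapt_custom_prompt: copy every non-dunder attribute into the target globals.
target_globals = globals()
for name in dir(module):
    if not name.startswith("__"):
        target_globals[name] = getattr(module, name)

print(REACT_SYS_MESSAGE)  # overridden system message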
11 changes: 11 additions & 0 deletions comps/agent/langchain/tools/custom_prompt.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+REACT_SYS_MESSAGE = """\
+Custom_prmpt !!!!!!!!!! Decompose the user request into a series of simple tasks when necessary and solve the problem step by step.
+When you cannot get the answer at first, do not give up. Reflect on the info you have from the tools and try to solve the problem in a different way.
+Please follow these guidelines when formulating your answer:
+1. If the question contains a false premise or assumption, answer “invalid question”.
+2. If you are uncertain or do not know the answer, respond with “I don’t know”.
+3. Give concise, factual and relevant answers.
+"""
17 changes: 0 additions & 17 deletions tests/agent/Dockerfile.hpu

This file was deleted.

2 changes: 2 additions & 0 deletions tests/agent/react_vllm.yaml
@@ -7,6 +7,7 @@ services:
     container_name: test-comps-agent-endpoint
     volumes:
       - ${TOOLSET_PATH}:/home/user/tools/
+      #- ${WORKPATH}/comps/agent:/home/user/comps/agent
     ports:
       - "9095:9095"
     ipc: host
@@ -28,3 +29,4 @@ services:
       http_proxy: ${http_proxy}
       https_proxy: ${https_proxy}
       port: 9095
+      custom_prompt: /home/user/tools/custom_prompt.py
4 changes: 2 additions & 2 deletions tests/agent/test_agent_langchain_on_intel_hpu.sh
@@ -12,6 +12,8 @@ tgi_volume=$WORKPATH/data
 vllm_port=8086
 vllm_volume=$WORKPATH/data

+export WORKPATH=$WORKPATH
+
 export agent_image="opea/agent-langchain:comps"
 export agent_container_name="test-comps-agent-endpoint"

@@ -46,8 +48,6 @@ function build_vllm_docker_images() {
     echo $WORKPATH
     if [ ! -d "./vllm" ]; then
         git clone https://github.com/HabanaAI/vllm-fork.git
-        cd ./vllm-fork; git checkout habana_main; git tag v0.6.2.post1; cd ..
-        cp $WORKPATH/tests/agent/Dockerfile.hpu ./vllm-fork
     fi
     cd ./vllm-fork
     docker build -f Dockerfile.hpu -t opea/vllm:hpu --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy
