-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RAG agent and ReAct agent implemention for llama3.1 served by TGI…
…-gaudi (#722) * add ragagent and react agent for llama3.1 Signed-off-by: minmin-intel <[email protected]> * update ut Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update test Signed-off-by: minmin-intel <[email protected]> * update test Signed-off-by: minmin-intel <[email protected]> * debug ut Signed-off-by: minmin-intel <[email protected]> * update test Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update test and readme Signed-off-by: minmin-intel <[email protected]> * update ragagent llama docgrader Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: minmin-intel <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
cc80c1b
commit e7fdf53
Showing
15 changed files
with
669 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import uuid | ||
|
||
from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall | ||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage | ||
from langchain_core.messages.tool import ToolCall | ||
from langchain_core.output_parsers import BaseOutputParser | ||
|
||
|
||
class QueryWriterLlamaOutputParser(BaseOutputParser): | ||
def parse(self, text: str): | ||
print("raw output from llm: ", text) | ||
json_lines = text.split("\n") | ||
print("json_lines: ", json_lines) | ||
output = [] | ||
for line in json_lines: | ||
try: | ||
output.append(json.loads(line)) | ||
except Exception as e: | ||
print("Exception happened in output parsing: ", str(e)) | ||
if output: | ||
return output | ||
else: | ||
return None | ||
|
||
|
||
def convert_json_to_tool_call(json_str, tool): | ||
tool_name = tool.name | ||
tcid = str(uuid.uuid4()) | ||
add_kw_tc = { | ||
"tool_calls": [ | ||
ChatCompletionOutputToolCall( | ||
function=ChatCompletionOutputFunctionDefinition( | ||
arguments={"query": json_str["query"]}, name=tool_name, description=None | ||
), | ||
id=tcid, | ||
type="function", | ||
) | ||
] | ||
} | ||
tool_call = ToolCall(name=tool_name, args={"query": json_str["query"]}, id=tcid) | ||
return add_kw_tc, tool_call | ||
|
||
|
||
def assemble_history(messages): | ||
""" | ||
messages: AI (query writer), TOOL (retriever), HUMAN (Doc Grader), AI, TOOL, HUMAN, etc. | ||
""" | ||
query_history = "" | ||
n = 1 | ||
for m in messages[1:]: # exclude the first message | ||
if isinstance(m, AIMessage): | ||
# if there is tool call | ||
if hasattr(m, "tool_calls") and len(m.tool_calls) > 0: | ||
for tool_call in m.tool_calls: | ||
query = tool_call["args"]["query"] | ||
query_history += f"{n}. {query}\n" | ||
n += 1 | ||
return query_history | ||
|
||
|
||
def aggregate_docs(messages): | ||
""" | ||
messages: AI (query writer), TOOL (retriever), HUMAN (Doc Grader | ||
""" | ||
docs = [] | ||
context = "" | ||
for m in messages[::-1]: | ||
if isinstance(m, ToolMessage): | ||
docs.append(m.content) | ||
elif isinstance(m, AIMessage): | ||
break | ||
for doc in docs[::-1]: | ||
context = context + doc + "\n" | ||
return context |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.