Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add translation chain #30

Merged
merged 1 commit into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/tablegpt/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def create_tablegpt_graph(
*,
session_id: str | None = None,
workdir: Path | None = None,
model_type: str | None = None,
error_trace_cleanup: bool = False,
nlines: int | None = None,
vlm: BaseLanguageModel | None = None,
safety_llm: Runnable | None = None,
dataset_retriever: BaseRetriever | None = None,
normalize_llm: BaseLanguageModel | None = None,
locale: str | None = None,
checkpointer: BaseCheckpointSaver | None = None,
verbose: bool = False,
) -> CompiledStateGraph:
Expand All @@ -59,13 +59,13 @@ def create_tablegpt_graph(
pybox_manager (BasePyBoxManager): A python code sandbox delegator, used to execute the data analysis code generated by llm.
session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
model_type (str | None, optional): Read the data header into different formats according to the model type. Defaults to None.
error_trace_cleanup (bool, optional): Flag to clean up error traces. Defaults to False.
nlines (int | None, optional): Number of lines to read for preview. Defaults to None.
vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None.
safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None.
dataset_retriever (BaseRetriever | None, optional): Component to retrieve datasets. Defaults to None.
normalize_llm (BaseLanguageModel | None, optional): Model for data normalization tasks. Defaults to None.
locale (str | None, optional): The locale of the user. Defaults to None.
checkpointer (BaseCheckpointSaver | None, optional): Component for saving checkpoints. Defaults to None.
verbose (bool, optional): Flag to enable verbose logging. Defaults to False.
Expand All @@ -75,11 +75,12 @@ def create_tablegpt_graph(
workflow = StateGraph(AgentState)
file_reading_graph = create_file_reading_workflow(
nlines=nlines,
llm=llm,
pybox_manager=pybox_manager,
workdir=workdir,
session_id=session_id,
model_type=model_type,
normalize_llm=normalize_llm,
locale=locale,
verbose=verbose,
)
data_analyze_graph = create_data_analyze_workflow(
Expand Down
25 changes: 23 additions & 2 deletions src/tablegpt/agent/file_reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_table_reformat_chain,
wrap_normalize_code,
)
from tablegpt.chains.translation import get_translation_chain
from tablegpt.errors import NoAttachmentsError
from tablegpt.tools import IPythonTool, markdown_console_template
from tablegpt.utils import get_raw_table_info
Expand Down Expand Up @@ -47,24 +48,26 @@ class AgentState(MessagesState):


def create_file_reading_workflow(
llm: BaseLanguageModel,
pybox_manager: BasePyBoxManager,
*,
workdir: Path | None = None,
session_id: str | None = None,
nlines: int | None = None,
model_type: str | None = None,
normalize_llm: BaseLanguageModel | None = None,
locale: str | None = None,
verbose: bool = False,
):
"""Create a workflow for reading and processing files using an agent-based approach.
Args:
llm (BaseLanguageModel): The primary language model for processing user input.
pybox_manager (BasePyBoxManager): A Python code sandbox delegator.
workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
nlines (int | None, optional): The number of lines to display from the dataset head. Defaults to 5 if not provided.
model_type (str | None, optional): Read the data header into different formats according to the model type. Defaults to None.
normalize_llm (BaseLanguageModel | None, optional): An optional language model used for data normalization. Defaults to None.
locale (str | None, optional): The locale of the user. Defaults to None.
verbose (bool, optional): Flag to enable verbose logging for debugging. Defaults to False.
Returns:
Expand All @@ -73,6 +76,15 @@ def create_file_reading_workflow(
if nlines is None:
nlines = 5

# Read the data header into different formats according to the model type.
model_type = None
if llm.metadata is not None:
model_type = llm.metadata.get("model_type")

translation_chain = None
if locale is not None:
translation_chain = get_translation_chain(llm=llm)

tools = [IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)]
tool_executor = ToolNode(tools)

Expand Down Expand Up @@ -122,6 +134,9 @@ async def get_df_info(state: AgentState) -> dict:
var_name = state["entry_message"].additional_kwargs.get("var_name", "df")

thought = f"我已经收到您的数据文件,我需要查看文件内容以对数据集有一个初步的了解。首先我会读取数据到 `{var_name}` 变量中,并通过 `{var_name}.info` 查看 NaN 情况和数据类型。" # noqa: RUF001
if translation_chain is not None:
thought = await translation_chain.ainvoke(input={"locale": locale, "input": thought})

read_df_code = f"""# Load the data into a DataFrame
{var_name} = read_df('{filename}')"""

Expand Down Expand Up @@ -175,6 +190,8 @@ def get_df_head(state: AgentState) -> dict:
var_name = state["entry_message"].additional_kwargs.get("var_name", "df")

thought = f"""接下来我将用 `{var_name}.head({nlines})` 来查看数据集的前 {nlines} 行。"""
if translation_chain is not None:
thought = translation_chain.invoke(input={"locale": locale, "input": thought})

# The input visible to the LLM can prevent it from blindly imitating the actions of our encoder.
default_tool_input = f"""# Show the first {nlines} rows to understand the structure
Expand Down Expand Up @@ -231,6 +248,10 @@ def get_final_answer(state: AgentState) -> dict:
raise NoAttachmentsError

text = f"我已经了解了数据集 {filename} 的基本信息。请问我可以帮您做些什么?" # noqa: RUF001

if translation_chain is not None:
text = translation_chain.invoke(input={"locale": locale, "input": text})

return {
"messages": [
AIMessage(
Expand Down
28 changes: 28 additions & 0 deletions src/tablegpt/chains/translation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable

if TYPE_CHECKING:
from langchain_core.language_models import BaseLanguageModel


# System instruction for the translator; the `{locale}` placeholder is filled
# in from the mapping passed when the prompt is invoked.
TRANSLATION_PROMPT = """You are a translation assistant. Translate user input directly into the primary language of the {locale} region without explanation."""

# Two-message chat prompt: the system instruction above, followed by the text
# to translate, which arrives through the `{input}` variable.
translation_prompt_template = ChatPromptTemplate.from_messages(
    [("system", TRANSLATION_PROMPT), ("human", "{input}")]
)

# Single parser instance reused as the last stage of the translation chain.
output_parser = StrOutputParser()


def get_translation_chain(llm: BaseLanguageModel) -> Runnable:
    """Build the runnable that translates text for a given locale.

    The chain pipes ``translation_prompt_template`` into ``llm`` and then into
    the module-level ``output_parser`` (a ``StrOutputParser``). The prompt
    declares two variables, so the chain is invoked with a mapping that
    provides both ``locale`` and ``input``.

    Args:
        llm: The language model that performs the translation.

    Returns:
        The composed ``prompt | llm | output_parser`` runnable.
    """
    return translation_prompt_template | llm | output_parser