From bf852feac39d245211cbbf135cfa621085e94ebe Mon Sep 17 00:00:00 2001 From: afsu Date: Mon, 4 Nov 2024 18:53:32 +0800 Subject: [PATCH] feat: add translation chain --- src/tablegpt/agent/__init__.py | 7 ++++--- src/tablegpt/agent/file_reading.py | 25 +++++++++++++++++++++++-- src/tablegpt/chains/translation.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) create mode 100644 src/tablegpt/chains/translation.py diff --git a/src/tablegpt/agent/__init__.py b/src/tablegpt/agent/__init__.py index 356b8da..2476d79 100644 --- a/src/tablegpt/agent/__init__.py +++ b/src/tablegpt/agent/__init__.py @@ -38,13 +38,13 @@ def create_tablegpt_graph( *, session_id: str | None = None, workdir: Path | None = None, - model_type: str | None = None, error_trace_cleanup: bool = False, nlines: int | None = None, vlm: BaseLanguageModel | None = None, safety_llm: Runnable | None = None, dataset_retriever: BaseRetriever | None = None, normalize_llm: BaseLanguageModel | None = None, + locale: str | None = None, checkpointer: BaseCheckpointSaver | None = None, verbose: bool = False, ) -> CompiledStateGraph: @@ -59,13 +59,13 @@ def create_tablegpt_graph( pybox_manager (BasePyBoxManager): A python code sandbox delegator, used to execute the data analysis code generated by llm. session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None. workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None. - model_type (str | None, optional): Read the data header into different formats according to the model type. Defaults to None. error_trace_cleanup (bool, optional): Flag to clean up error traces. Defaults to False. nlines (int | None, optional): Number of lines to read for preview. Defaults to None. vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None. 
safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None. dataset_retriever (BaseRetriever | None, optional): Component to retrieve datasets. Defaults to None. normalize_llm (BaseLanguageModel | None, optional): Model for data normalization tasks. Defaults to None. + locale (str | None, optional): The locale of the user. Defaults to None. checkpointer (BaseCheckpointSaver | None, optional): Component for saving checkpoints. Defaults to None. verbose (bool, optional): Flag to enable verbose logging. Defaults to False. @@ -75,11 +75,12 @@ def create_tablegpt_graph( workflow = StateGraph(AgentState) file_reading_graph = create_file_reading_workflow( nlines=nlines, + llm=llm, pybox_manager=pybox_manager, workdir=workdir, session_id=session_id, - model_type=model_type, normalize_llm=normalize_llm, + locale=locale, verbose=verbose, ) data_analyze_graph = create_data_analyze_workflow( diff --git a/src/tablegpt/agent/file_reading.py b/src/tablegpt/agent/file_reading.py index 6174f02..c5eef6c 100644 --- a/src/tablegpt/agent/file_reading.py +++ b/src/tablegpt/agent/file_reading.py @@ -16,6 +16,7 @@ get_table_reformat_chain, wrap_normalize_code, ) +from tablegpt.chains.translation import get_translation_chain from tablegpt.errors import NoAttachmentsError from tablegpt.tools import IPythonTool, markdown_console_template from tablegpt.utils import get_raw_table_info @@ -47,24 +48,26 @@ class AgentState(MessagesState): def create_file_reading_workflow( + llm: BaseLanguageModel, pybox_manager: BasePyBoxManager, *, workdir: Path | None = None, session_id: str | None = None, nlines: int | None = None, - model_type: str | None = None, normalize_llm: BaseLanguageModel | None = None, + locale: str | None = None, verbose: bool = False, ): """Create a workflow for reading and processing files using an agent-based approach. Args: + llm (BaseLanguageModel): The primary language model for processing user input. 
pybox_manager (BasePyBoxManager): A Python code sandbox delegator. workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None. session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None. nlines (int | None, optional): The number of lines to display from the dataset head. Defaults to 5 if not provided. - model_type (str | None, optional): Read the data header into different formats according to the model type. Defaults to None. normalize_llm (BaseLanguageModel | None, optional): An optional language model used for data normalization. Defaults to None. + locale (str | None, optional): The locale of the user. Defaults to None. verbose (bool, optional): Flag to enable verbose logging for debugging. Defaults to False. Returns: @@ -73,6 +76,15 @@ def create_file_reading_workflow( if nlines is None: nlines = 5 + # Read the data header into different formats according to the model type. + model_type = None + if llm.metadata is not None: + model_type = llm.metadata.get("model_type") + + translation_chain = None + if locale is not None: + translation_chain = get_translation_chain(llm=llm) + tools = [IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)] tool_executor = ToolNode(tools) @@ -122,6 +134,9 @@ async def get_df_info(state: AgentState) -> dict: var_name = state["entry_message"].additional_kwargs.get("var_name", "df") thought = f"我已经收到您的数据文件,我需要查看文件内容以对数据集有一个初步的了解。首先我会读取数据到 `{var_name}` 变量中,并通过 `{var_name}.info` 查看 NaN 情况和数据类型。" # noqa: RUF001 + if translation_chain is not None: + thought = await translation_chain.ainvoke(input={"locale": locale, "input": thought}) + read_df_code = f"""# Load the data into a DataFrame {var_name} = read_df('(unknown)')""" @@ -175,6 +190,8 @@ def get_df_head(state: AgentState) -> dict: var_name = state["entry_message"].additional_kwargs.get("var_name", "df") thought = f"""接下来我将用 `{var_name}.head({nlines})` 来查看数据集的前 
{nlines} 行。""" + if translation_chain is not None: + thought = translation_chain.invoke(input={"locale": locale, "input": thought}) # The input visible to the LLM can prevent it from blindly imitating the actions of our encoder. default_tool_input = f"""# Show the first {nlines} rows to understand the structure @@ -231,6 +248,10 @@ def get_final_answer(state: AgentState) -> dict: raise NoAttachmentsError text = f"我已经了解了数据集 (unknown) 的基本信息。请问我可以帮您做些什么?" # noqa: RUF001 + + if translation_chain is not None: + text = translation_chain.invoke(input={"locale": locale, "input": text}) + return { "messages": [ AIMessage( diff --git a/src/tablegpt/chains/translation.py b/src/tablegpt/chains/translation.py new file mode 100644 index 0000000..3dea3e9 --- /dev/null +++ b/src/tablegpt/chains/translation.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import Runnable + +if TYPE_CHECKING: + from langchain_core.language_models import BaseLanguageModel + + +TRANSLATION_PROMPT = """You are a translation assistant. Translate user input directly into the primary language of the {locale} region without explanation.""" + + +translation_prompt_template = ChatPromptTemplate.from_messages( + [ + ("system", TRANSLATION_PROMPT), + ("human", "{input}"), + ] +) + +output_parser = StrOutputParser() + + +def get_translation_chain(llm: BaseLanguageModel) -> Runnable: + """Return the translation chain runnable.""" + return translation_prompt_template | llm | output_parser