feat: add translation chain
vegetablest authored and af su committed Nov 4, 2024
1 parent 23c50e7 commit bf852fe
Showing 3 changed files with 55 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/tablegpt/agent/__init__.py
@@ -38,13 +38,13 @@ def create_tablegpt_graph(
     *,
     session_id: str | None = None,
     workdir: Path | None = None,
-    model_type: str | None = None,
     error_trace_cleanup: bool = False,
     nlines: int | None = None,
     vlm: BaseLanguageModel | None = None,
     safety_llm: Runnable | None = None,
     dataset_retriever: BaseRetriever | None = None,
     normalize_llm: BaseLanguageModel | None = None,
+    locale: str | None = None,
     checkpointer: BaseCheckpointSaver | None = None,
     verbose: bool = False,
 ) -> CompiledStateGraph:
@@ -59,13 +59,13 @@ def create_tablegpt_graph(
         pybox_manager (BasePyBoxManager): A python code sandbox delegator, used to execute the data analysis code generated by llm.
         session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
         workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
-        model_type (str | None, optional): Read the data header into different formats according to the model type. Defaults to None.
         error_trace_cleanup (bool, optional): Flag to clean up error traces. Defaults to False.
         nlines (int | None, optional): Number of lines to read for preview. Defaults to None.
         vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None.
         safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None.
         dataset_retriever (BaseRetriever | None, optional): Component to retrieve datasets. Defaults to None.
         normalize_llm (BaseLanguageModel | None, optional): Model for data normalization tasks. Defaults to None.
+        locale (str | None, optional): The locale of the user. Defaults to None.
         checkpointer (BaseCheckpointSaver | None, optional): Component for saving checkpoints. Defaults to None.
         verbose (bool, optional): Flag to enable verbose logging. Defaults to False.
 
@@ -75,11 +75,12 @@ def create_tablegpt_graph(
     workflow = StateGraph(AgentState)
     file_reading_graph = create_file_reading_workflow(
         nlines=nlines,
+        llm=llm,
         pybox_manager=pybox_manager,
         workdir=workdir,
         session_id=session_id,
-        model_type=model_type,
         normalize_llm=normalize_llm,
+        locale=locale,
         verbose=verbose,
     )
     data_analyze_graph = create_data_analyze_workflow(
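With these changes, model_type leaves the public signature of `create_tablegpt_graph` (it is now read from the model's metadata, as `file_reading.py` below shows) and `locale` becomes the graph-level switch for translated output. A minimal caller sketch, assuming `llm` remains the first parameter of `create_tablegpt_graph`; `ChatOpenAI`, the `pybox` import, and the "markup" model type are placeholders, not prescribed by this commit:

from langchain_openai import ChatOpenAI
from pybox import LocalPyBoxManager  # assumed from the project's docs; any BasePyBoxManager works
from tablegpt.agent import create_tablegpt_graph

# model_type now travels on the model itself instead of a graph keyword.
llm = ChatOpenAI(model="gpt-4o-mini", metadata={"model_type": "markup"})

graph = create_tablegpt_graph(
    llm=llm,
    pybox_manager=LocalPyBoxManager(),
    locale="en-US",  # new in this commit: built-in Chinese thoughts are translated for this locale
)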
25 changes: 23 additions & 2 deletions src/tablegpt/agent/file_reading.py
@@ -16,6 +16,7 @@
     get_table_reformat_chain,
     wrap_normalize_code,
 )
+from tablegpt.chains.translation import get_translation_chain
 from tablegpt.errors import NoAttachmentsError
 from tablegpt.tools import IPythonTool, markdown_console_template
 from tablegpt.utils import get_raw_table_info
@@ -47,24 +48,26 @@ class AgentState(MessagesState):
 
 
 def create_file_reading_workflow(
+    llm: BaseLanguageModel,
     pybox_manager: BasePyBoxManager,
     *,
     workdir: Path | None = None,
     session_id: str | None = None,
     nlines: int | None = None,
-    model_type: str | None = None,
     normalize_llm: BaseLanguageModel | None = None,
+    locale: str | None = None,
     verbose: bool = False,
 ):
     """Create a workflow for reading and processing files using an agent-based approach.
 
     Args:
+        llm (BaseLanguageModel): The primary language model for processing user input.
         pybox_manager (BasePyBoxManager): A Python code sandbox delegator.
         workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
         session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
         nlines (int | None, optional): The number of lines to display from the dataset head. Defaults to 5 if not provided.
-        model_type (str | None, optional): Read the data header into different formats according to the model type. Defaults to None.
         normalize_llm (BaseLanguageModel | None, optional): An optional language model used for data normalization. Defaults to None.
+        locale (str | None, optional): The locale of the user. Defaults to None.
         verbose (bool, optional): Flag to enable verbose logging for debugging. Defaults to False.
 
     Returns:
@@ -73,6 +76,15 @@ def create_file_reading_workflow(
     if nlines is None:
         nlines = 5
 
+    # Read the data header into different formats according to the model type.
+    model_type = None
+    if llm.metadata is not None:
+        model_type = llm.metadata.get("model_type")
+
+    translation_chain = None
+    if locale is not None:
+        translation_chain = get_translation_chain(llm=llm)
+
     tools = [IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)]
     tool_executor = ToolNode(tools)
 
@@ -122,6 +134,9 @@ async def get_df_info(state: AgentState) -> dict:
         var_name = state["entry_message"].additional_kwargs.get("var_name", "df")
 
         thought = f"我已经收到您的数据文件,我需要查看文件内容以对数据集有一个初步的了解。首先我会读取数据到 `{var_name}` 变量中,并通过 `{var_name}.info` 查看 NaN 情况和数据类型。"  # noqa: RUF001
+        if translation_chain is not None:
+            thought = await translation_chain.ainvoke(input={"locale": locale, "input": thought})
+
         read_df_code = f"""# Load the data into a DataFrame
 {var_name} = read_df('{filename}')"""
 
@@ -175,6 +190,8 @@ def get_df_head(state: AgentState) -> dict:
         var_name = state["entry_message"].additional_kwargs.get("var_name", "df")
 
         thought = f"""接下来我将用 `{var_name}.head({nlines})` 来查看数据集的前 {nlines} 行。"""
+        if translation_chain is not None:
+            thought = translation_chain.invoke(input={"locale": locale, "input": thought})
 
         # The input visible to the LLM can prevent it from blindly imitating the actions of our encoder.
         default_tool_input = f"""# Show the first {nlines} rows to understand the structure
@@ -231,6 +248,10 @@ def get_final_answer(state: AgentState) -> dict:
             raise NoAttachmentsError
 
         text = f"我已经了解了数据集 {filename} 的基本信息。请问我可以帮您做些什么?"  # noqa: RUF001
+
+        if translation_chain is not None:
+            text = translation_chain.invoke(input={"locale": locale, "input": text})
+
         return {
             "messages": [
                 AIMessage(
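Two details of the wiring above are easy to miss. First, the async `get_df_info` node awaits `translation_chain.ainvoke(...)`, while the sync `get_df_head` and `get_final_answer` nodes call `invoke(...)` on the same chain. Second, the hard-coded thoughts being translated are Chinese (the first roughly reads "I have received your data file; I need to inspect its contents to get an initial sense of the dataset"), so without a `locale` the agent keeps its original Chinese narration. The metadata lookup itself is plain LangChain; a self-contained sketch using a canned-response fake model, with "markup" again an illustrative value:

from langchain_core.language_models import FakeListChatModel

# metadata is a standard constructor field on LangChain models, so callers
# can attach hints such as model_type when building the llm.
llm = FakeListChatModel(responses=["ok"], metadata={"model_type": "markup"})

# Mirrors the lookup at the top of create_file_reading_workflow.
model_type = None
if llm.metadata is not None:
    model_type = llm.metadata.get("model_type")

print(model_type)  # -> markup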
28 changes: 28 additions & 0 deletions src/tablegpt/chains/translation.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable
+
+if TYPE_CHECKING:
+    from langchain_core.language_models import BaseLanguageModel
+
+
+TRANSLATION_PROMPT = """You are a translation assistant. Translate user input directly into the primary language of the {locale} region without explanation."""
+
+
+translation_prompt_template = ChatPromptTemplate.from_messages(
+    [
+        ("system", TRANSLATION_PROMPT),
+        ("human", "{input}"),
+    ]
+)
+
+output_parser = StrOutputParser()
+
+
+def get_translation_chain(llm: BaseLanguageModel) -> Runnable:
+    """Return the translation chain runnable."""
+    return translation_prompt_template | llm | output_parser
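The new chain is ordinary LCEL: a system prompt parameterized on {locale}, the text to translate as the human turn, and a string parser, piped together. A usage sketch with a canned-response fake chat model so it runs offline; swap in any real chat model for actual translations:

from langchain_core.language_models import FakeListChatModel

from tablegpt.chains.translation import get_translation_chain

# Replays a fixed answer; a real model would produce the translation.
llm = FakeListChatModel(responses=["I have received your data file. What can I do for you?"])
chain = get_translation_chain(llm=llm)

# The dict keys match the prompt variables; StrOutputParser reduces the reply to a plain str.
translated = chain.invoke({"locale": "en-US", "input": "我已经收到您的数据文件。请问我可以帮您做些什么?"})
print(translated)

Because the chain is a Runnable, the async node in file_reading.py can `await chain.ainvoke(...)` on the very same object.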
