From c8ad59b688e4f8466360d430aa111d1c2bfcecf5 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Tue, 2 Jul 2024 13:27:28 +0530 Subject: [PATCH 1/4] feat: Add RetrySqlQueryCreatorTool for handling failed SQL query generation --- .../agent_toolkits/sqlcoder/toolkit.py | 2 + .../tools/sql_coder/tool.py | 141 +++++++++++++----- .../langchain/tools/sqlcoder/prompt.py | 16 +- 3 files changed, 120 insertions(+), 39 deletions(-) diff --git a/libs/community/langchain_community/agent_toolkits/sqlcoder/toolkit.py b/libs/community/langchain_community/agent_toolkits/sqlcoder/toolkit.py index 4cb97b002c40a..7a2924bdd5221 100644 --- a/libs/community/langchain_community/agent_toolkits/sqlcoder/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/sqlcoder/toolkit.py @@ -13,6 +13,7 @@ from langchain_community.tools.sql_coder.tool import ( QuerySparkSQLDataBaseTool, SqlQueryCreatorTool, + RetrySqlQueryCreatorTool ) class SQLCoderToolkit(BaseToolkit): @@ -54,6 +55,7 @@ def get_tools(self) -> List[BaseTool]: db=self.db, description=query_sql_database_tool_description ), QuerySQLCheckerTool(db=self.db, llm=self.llm), + RetrySqlQueryCreatorTool(sqlcreatorllm=self.sqlcreatorllm), SqlQueryCreatorTool( sqlcreatorllm=self.sqlcreatorllm , db=self.db, diff --git a/libs/community/langchain_community/tools/sql_coder/tool.py b/libs/community/langchain_community/tools/sql_coder/tool.py index 2430d482d520d..6622ec1dbc43e 100644 --- a/libs/community/langchain_community/tools/sql_coder/tool.py +++ b/libs/community/langchain_community/tools/sql_coder/tool.py @@ -14,6 +14,7 @@ from langchain_core.tools import StateTool import re +ERROR = "" class BaseSQLDatabaseTool(BaseModel): """Base tool for interacting with a SQL database.""" @@ -43,7 +44,7 @@ class Config(StateTool.Config): description: str = """ Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. - If an error is returned, re-run the sql_db_query_creator tool to get the correct query. + If an error is returned, re-run the retry_sql_db_query_creator tool to get the correct query. """ def __init__(__pydantic_self__, **data: Any) -> None: @@ -65,6 +66,7 @@ def _run( ) executable_query = executable_query.strip('\"') executable_query = re.sub('\\n```', '',executable_query) + self.db.run_no_throw(executable_query) return self.db.run_no_throw(executable_query) async def _arun( @@ -75,14 +77,98 @@ async def _arun( raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async") def _extract_sql_query(self): - for value in self.state: + for value in reversed(self.state): for key, input_string in value.items(): - if "sql_db_query_creator" in key: + if "tool='retry_sql_db_query_creator'" in key: + return input_string + elif "tool='sql_db_query_creator'" in key: return input_string return None + + + +class RetrySqlQueryCreatorTool(StateTool): + """Tool for re-creating SQL query.Use this to retry creation of sql query.""" + + name = "retry_sql_db_query_creator" + description = """ + This is a tool used to re-create sql query for user input based on the incorrect query generated and error returned from sql_db_query tool. + Input to this tool is user prompt, incorrect sql query and error message + Output is a sql query + After running this tool, you can run sql_db_query tool to get the result + """ + sqlcreatorllm: BaseLanguageModel = Field(exclude=True) + + + class Config(StateTool.Config): + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + extra = Extra.allow + + def __init__(__pydantic_self__, **data: Any) -> None: + """Initialize the tool.""" + super().__init__(**data) + + def _run( + self, + user_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the SQL query for the incorrect query.""" + return self._create_sql_query(user_input) + + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + raise NotImplementedError("SqlQueryCreatorTool does not support async") + def _create_sql_query(self,user_input): + + sql_query = self._extract_sql_query() + error_message = self._extract_error_message() + if sql_query is None: + return "This tool is not meant to be run directly. Start with a SQLQueryCreatorTool" + + prompt_input = PromptTemplate( + input_variables=["user_input","sql_query", "error_message"], + template=SQL_QUERY_CREATOR_RETRY + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + sql_query = query_creator_chain.run( + ( + { + "sql_query": sql_query, + "error_message": error_message, + "user_input": user_input + } + ) + ) + sql_query = sql_query.replace("```","") + sql_query = sql_query.replace("sql","") + + return sql_query + + def _extract_sql_query(self): + for value in reversed(self.state): + for key, input_string in value.items(): + if "tool='retry_sql_db_query_creator'" in key: + return input_string + elif "tool='sql_db_query_creator'" in key: + return input_string + return None + def _extract_error_message(self): + for value in reversed(self.state): + for key, input_string in value.items(): + if "tool='sql_db_query'" in key: + if "Error" in input_string: + return input_string + return None + class SqlQueryCreatorTool(StateTool): """Tool for creating SQL query.Use this to create sql query.""" @@ -147,43 +233,24 @@ def _parse_data_model_context(self): def _create_sql_query(self,user_input): few_shot_examples = self._parse_few_shot_examples() - sql_query = self._extract_sql_query() db_schema = self._parse_db_schema() data_model_context = self._parse_data_model_context() - if sql_query is None: - prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], - template=self.SQL_QUERY_CREATOR_TEMPLATE, - ) - query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - sql_query = query_creator_chain.run( - ( - { - "db_schema": db_schema, - "user_input": user_input, - "few_shot_examples": few_shot_examples, - "data_model_context": data_model_context - } - ) - ) - else: - prompt_input = PromptTemplate( - input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], - template=SQL_QUERY_CREATOR_RETRY - ) - query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) - - sql_query = query_creator_chain.run( - ( - { - "db_schema": db_schema, - "user_input": user_input, - "few_shot_examples": few_shot_examples, - "data_model_context": data_model_context - } - ) + prompt_input = PromptTemplate( + input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], + template=self.SQL_QUERY_CREATOR_TEMPLATE, + ) + query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) + + sql_query = query_creator_chain.run( + ( + { + "db_schema": db_schema, + "user_input": user_input, + "few_shot_examples": few_shot_examples, + "data_model_context": data_model_context + } ) + ) sql_query = sql_query.replace("```","") sql_query = sql_query.replace("sql","") diff --git a/libs/langchain/langchain/tools/sqlcoder/prompt.py b/libs/langchain/langchain/tools/sqlcoder/prompt.py index dffbf4f2e136a..af730b7e31fbe 100644 --- a/libs/langchain/langchain/tools/sqlcoder/prompt.py +++ b/libs/langchain/langchain/tools/sqlcoder/prompt.py @@ -1,8 +1,20 @@ SQL_QUERY_CREATOR_RETRY = """ -You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query. -""" +Your task is convert an incorrect query resulting from user question to a correct query which is databricks sql compatible. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float + +### Task: +Generate a correct SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION]. +The query you will correct is: {sql_query} +The error message is: {error_message} + +### Response: +Based on your instructions, here is the SQL query I have generated +[SQL]""" SQL_QUERY_CREATOR_7b = """ ### Instructions: From e308c7804ae721c7be26d3096b2faa8884a2b723 Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Wed, 3 Jul 2024 14:56:53 +0530 Subject: [PATCH 2/4] chore: Bump langchain version to 0.2.12dev1 --- libs/langchain/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index be13232dae869..8c06059ea4861 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.2.11dev1" +version = "0.2.12dev1" description = "Building applications with LLMs through composability" authors = [] license = "MIT" From 852d2e71ead28870937394568d84d19918de3a1c Mon Sep 17 00:00:00 2001 From: Sushant Burnawal Date: Sun, 14 Jul 2024 14:52:27 +0530 Subject: [PATCH 3/4] chore: Update langchain version to 0.2.7dev1 and include columns in query execution --- .../tools/sql_coder/prompt.py | 17 +++++++++++++++-- .../langchain_community/tools/sql_coder/tool.py | 2 +- libs/community/pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/tools/sql_coder/prompt.py b/libs/community/langchain_community/tools/sql_coder/prompt.py index dffbf4f2e136a..d7e7026bbd159 100644 --- a/libs/community/langchain_community/tools/sql_coder/prompt.py +++ b/libs/community/langchain_community/tools/sql_coder/prompt.py @@ -1,8 +1,21 @@ SQL_QUERY_CREATOR_RETRY = """ -You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query. -""" +### Instructions: +Your task is convert an incorrect query resulting from user question to a correct query which is databricks sql compatible. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. +- When creating a ratio, always cast the numerator as float + +### Task: +Generate a correct SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION]. +The query you will correct is: {sql_query} +The error message is: {error_message} + +### Response: +Based on your instructions, here is the SQL query I have generated +[SQL]""" SQL_QUERY_CREATOR_7b = """ ### Instructions: diff --git a/libs/community/langchain_community/tools/sql_coder/tool.py b/libs/community/langchain_community/tools/sql_coder/tool.py index 6622ec1dbc43e..f78c6b70e4ac7 100644 --- a/libs/community/langchain_community/tools/sql_coder/tool.py +++ b/libs/community/langchain_community/tools/sql_coder/tool.py @@ -67,7 +67,7 @@ def _run( executable_query = executable_query.strip('\"') executable_query = re.sub('\\n```', '',executable_query) self.db.run_no_throw(executable_query) - return self.db.run_no_throw(executable_query) + return self.db.run_no_throw(executable_query, include_columns=True) async def _arun( self, diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 8c9338e1ac395..7375e662cef66 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-community" -version = "0.2.5dev1" +version = "0.2.7dev1" description = "Community contributed LangChain integrations." authors = [] license = "MIT" From 3147269cd7b89c52b99fb45cf1dd29d77a5f8fff Mon Sep 17 00:00:00 2001 From: Arunraja Date: Wed, 4 Sep 2024 18:29:31 +0530 Subject: [PATCH 4/4] chore: Update langchain version to 0.2.8dev1 and include columns in query execution --- .../langchain_community/tools/spark_unitycatalog/tool.py | 5 ++++- libs/community/langchain_community/tools/sql_coder/tool.py | 4 +++- libs/community/pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/tools/spark_unitycatalog/tool.py b/libs/community/langchain_community/tools/spark_unitycatalog/tool.py index 0fefce555609f..c086990cf180c 100644 --- a/libs/community/langchain_community/tools/spark_unitycatalog/tool.py +++ b/libs/community/langchain_community/tools/spark_unitycatalog/tool.py @@ -355,7 +355,10 @@ def _run( else: executable_query = query.strip() executable_query = executable_query.strip('"') - return self.db.run_no_throw(executable_query) + executable_query = executable_query.rstrip(";") + if "LIMIT".lower() not in executable_query.lower(): + executable_query = f"{executable_query} LIMIT 50" + return self.db.run_no_throw(executable_query , include_columns= True) else: return "This tool is not meant to be run directly. Start with a ListUnityCatalogTablesTool" diff --git a/libs/community/langchain_community/tools/sql_coder/tool.py b/libs/community/langchain_community/tools/sql_coder/tool.py index f78c6b70e4ac7..7d381e3666850 100644 --- a/libs/community/langchain_community/tools/sql_coder/tool.py +++ b/libs/community/langchain_community/tools/sql_coder/tool.py @@ -66,7 +66,9 @@ def _run( ) executable_query = executable_query.strip('\"') executable_query = re.sub('\\n```', '',executable_query) - self.db.run_no_throw(executable_query) + executable_query = executable_query.rstrip(";") + if "LIMIT".lower() not in executable_query.lower(): + executable_query = f"{executable_query} LIMIT 50" return self.db.run_no_throw(executable_query, include_columns=True) async def _arun( diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 7375e662cef66..cf82263f41cea 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-community" -version = "0.2.7dev1" +version = "0.2.8dev1" description = "Community contributed LangChain integrations." authors = [] license = "MIT"